sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_vl.py (new file)
@@ -0,0 +1,724 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
+from functools import lru_cache, partial
+from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from vllm.config import CacheConfig, MultiModalConfig
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import QuickGELU
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsMultiModal
+
+from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2Model
+
+logger = init_logger(__name__)
+
+# === Vision Inputs === #
+
+
+class Qwen2VLImageInputs(TypedDict):
+    pixel_values: torch.Tensor
+    """Shape:
+    `(num_patches, num_channels * patch_size * patch_size)`
+    """
+
+    image_grid_thw: torch.Tensor
+    """Shape: `(num_images, 3)`
+
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+class Qwen2VLVideoInputs(TypedDict):
+    pixel_values_videos: torch.Tensor
+    """Shape:
+    `(num_patches,
+       num_channels * temporal_patch_size * patch_size * patch_size)`
+    """
+
+    video_grid_thw: torch.Tensor
+    """Shape: `(num_videos, 3)`
+
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+# === Vision Encoder === #
+
+
+class Qwen2VisionMLP(nn.Module):
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int = None,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            in_features, hidden_features, quant_config=quant_config
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            hidden_features, in_features, quant_config=quant_config
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class Qwen2VisionAttention(nn.Module):
+
+    def __init__(
+        self,
+        embed_dim: Optional[int] = None,
+        num_heads: Optional[int] = None,
+        projection_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        # Per attention head and per partition values.
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+
+        self.qkv = ColumnParallelLinear(
+            input_size=embed_dim,
+            output_size=3 * projection_size,
+            quant_config=quant_config,
+        )
+        self.proj = RowParallelLinear(
+            input_size=projection_size, output_size=embed_dim, quant_config=quant_config
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # [s, b, c] --> [s, b, head * 3 * head_dim]
+        x, _ = self.qkv(x)
+
+        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads_per_partition,
+            3 * self.hidden_size_per_attention_head,
+        )
+        x = x.view(*new_x_shape)
+
+        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        batch_size = q.shape[1]
+
+        q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = (seq_lens).max().item()
+        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+        output = torch.empty_like(q)
+        context_attention_fwd(
+            q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
+        )
+
+        context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
+        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+
+        output, _ = self.proj(context_layer)
+        return output
+
+
+class Qwen2VisionBlock(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+
+        self.attn = Qwen2VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            quant_config=quant_config,
+        )
+        self.mlp = Qwen2VisionMLP(
+            dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
+        )
+
+    def forward(
+        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+        )
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class Qwen2VisionPatchEmbed(nn.Module):
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_chans: int = 3,
+        embed_dim: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.embed_dim = embed_dim
+
+        kernel_size = [temporal_patch_size, patch_size, patch_size]
+        self.proj = nn.Conv3d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        L, C = x.shape
+        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
+        x = self.proj(x).view(L, self.embed_dim)
+        return x
+
+
+class Qwen2VisionPatchMerger(nn.Module):
+
+    def __init__(
+        self,
+        d_model: int,
+        context_dim: int,
+        norm_layer: Type[nn.Module] = None,
+        spatial_merge_size: int = 2,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.ln_q = norm_layer(context_dim)
+        self.mlp = nn.ModuleList(
+            [
+                ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias=True,
+                    quant_config=quant_config,
+                ),
+                nn.GELU(),
+                RowParallelLinear(
+                    self.hidden_size, d_model, bias=True, quant_config=quant_config
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.ln_q(x)
+        x = x.view(-1, self.hidden_size)
+
+        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
+        x_parallel, _ = mlp_fc1(x)
+        x_parallel = mlp_act(x_parallel)
+        out, _ = mlp_fc2(x_parallel)
+        return out
+
+
+class Qwen2VisionRotaryEmbedding(nn.Module):
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta
+                ** (
+                    torch.arange(
+                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
+                    )
+                    / self.dim
+                )
+            )
+            seq = torch.arange(
+                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+            )
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
+
+
+class Qwen2VisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        vision_config: Qwen2VLVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        patch_size: int = vision_config.patch_size
+        temporal_patch_size: int = vision_config.temporal_patch_size
+        spatial_merge_size: int = vision_config.spatial_merge_size
+        in_chans: int = vision_config.in_chans
+        hidden_size: int = vision_config.hidden_size
+        embed_dim: int = vision_config.embed_dim
+        depth: int = vision_config.depth
+        num_heads: int = vision_config.num_heads
+        mlp_ratio: float = vision_config.mlp_ratio
+
+        self.spatial_merge_size = spatial_merge_size
+
+        self.patch_embed = Qwen2VisionPatchEmbed(
+            patch_size=patch_size,
+            temporal_patch_size=temporal_patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        head_dim = embed_dim // num_heads
+        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = nn.ModuleList(
+            [
+                Qwen2VisionBlock(
+                    dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    norm_layer=norm_layer,
+                    quant_config=quant_config,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.merger = Qwen2VisionPatchMerger(
+            d_model=hidden_size,
+            context_dim=embed_dim,
+            norm_layer=norm_layer,
+            quant_config=quant_config,
+        )
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = (
+                hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            wpos_ids = (
+                wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+
+        # compute position embedding
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+        # transformers
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+        # adapter
+        x = self.merger(x)
+        return x
+
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
+        processor = cached_get_processor(self.config._name_or_path)
+        grid_t, grid_h, grid_w = image_grid_thw
+        num_image_tokens = (
+            grid_t
+            * grid_h
+            * grid_w
+            // processor.image_processor.merge_size
+            // processor.image_processor.merge_size
+        )
+        return num_image_tokens
+
+    # Use grid_t * grid_w * grid_h to pad tokens for each image
+    # and replaced padding by unique image hash
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_grid_thws = image_inputs.image_grid_thws
+        pad_values = image_inputs.pad_values
+
+        image_indices = [
+            idx
+            for idx, token in enumerate(input_ids)
+            if token == self.config.image_token_id
+        ]
+        image_inputs.image_offsets = []
+
+        input_ids_with_image = []
+        for image_cnt, _ in enumerate(image_grid_thws):
+            num_image_tokens = self.calculate_num_image_tokens(
+                image_grid_thws[image_cnt]
+            )
+            if image_cnt == 0:
+                non_image_tokens = input_ids[: image_indices[image_cnt]]
+            else:
+                non_image_tokens = input_ids[
+                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
+                ]
+            input_ids_with_image.extend(non_image_tokens)
+            image_inputs.image_offsets.append(len(input_ids_with_image))
+            pad_ids = pad_values * (
+                (num_image_tokens + len(pad_values)) // len(pad_values)
+            )
+            input_ids_with_image.extend(pad_ids[:num_image_tokens])
+        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
+
+        return input_ids_with_image
+
+    def __init__(
+        self,
+        config: Qwen2VLConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        self.visual = Qwen2VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            # NOTE: Qwen2-VL vision encoder does not support any
+            # quantization method now.
+            quant_config=None,
+        )
+
+        self.model = Qwen2Model(config, quant_config)
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
+        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+        return image_embeds
+
+    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
+        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
+        video_embeds = self.visual(
+            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
+        )
+        return video_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        """Run forward pass for Qwen2-VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,).
+                (Use input_metadata.mrope_positions to replace it)
+            pixel_values: Pixel values to be fed to a model.
+                `None` if no images are passed.
+            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
+                `None` if no images are passed.
+        """
+        image_inputs = None
+        if forward_batch.image_inputs is not None:
+            image_inputs = [
+                img for img in forward_batch.image_inputs if img is not None
+            ]
+
+        positions = forward_batch.mrope_positions
+        if (
+            forward_batch.forward_mode.is_decode()
+            or image_inputs is None
+            or len(image_inputs) == 0
+        ):
+            inputs_embeds = self.model.embed_tokens(input_ids)
+        else:
+            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
+            for i, image in enumerate(forward_batch.image_inputs):
+                if image == None:
+                    continue
+                start_idx = extend_start_loc_cpu[i]
+                prefix_len = prefix_lens_cpu[i]
+
+                pixel_values = torch.tensor(image.pixel_values, device="cuda")
+                image_grid_thws = torch.tensor(
+                    np.array(image.image_grid_thws), device="cuda"
+                )
+                image_offsets = image.image_offsets
+                image_input = Qwen2VLImageInputs(
+                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
+                )
+                image_embeds = self._process_image_input(image_input)
+
+                image_embeds_offset = 0
+                for idx, image_offset in enumerate(image_offsets):
+                    if image_offset < prefix_len:
+                        continue
+                    num_image_tokens = self.calculate_num_image_tokens(
+                        image_grid_thws[idx]
+                    )
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = (
+                        start_idx + (image_offset - prefix_len) + num_image_tokens
+                    )
+                    inputs_embeds[left_idx:right_idx] = image_embeds[
+                        image_embeds_offset : image_embeds_offset + num_image_tokens
+                    ]
+                    image_embeds_offset += num_image_tokens
+
+            input_ids = None
+
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=inputs_embeds,
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "visual" in name and "qkv.weight" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.embed_dim
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(
+                        3, visual_num_heads, head_size, visual_embed_dim
+                    )
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
+                elif "visual" in name and "qkv.bias" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.embed_dim
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1)
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen2VLForConditionalGeneration
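The arithmetic behind `calculate_num_image_tokens` and `pad_input_ids` in the file above determines how many placeholder ids each image occupies in the prompt: the `(grid_t, grid_h, grid_w)` patch grid is divided twice by the processor's spatial `merge_size`, and the resulting span is filled with the per-image `pad_values` hash ids repeated as needed. The offsets recorded this way are what `forward` later uses (via `image_offsets`) to scatter the vision-tower embeddings into `inputs_embeds`. A minimal standalone sketch of that arithmetic is shown below; the grid, merge size, and hash values are hypothetical example numbers, not taken from the package.

```python
# Illustrative sketch of the image-token padding arithmetic used by
# Qwen2VLForConditionalGeneration.calculate_num_image_tokens / pad_input_ids above.
# All concrete values here are made up for demonstration.

from typing import List, Tuple


def num_image_tokens(grid_thw: Tuple[int, int, int], merge_size: int = 2) -> int:
    # Each image contributes grid_t * grid_h * grid_w patches; spatially adjacent
    # patches are merged merge_size x merge_size at a time by the patch merger.
    grid_t, grid_h, grid_w = grid_thw
    return grid_t * grid_h * grid_w // merge_size // merge_size


def expand_image_placeholder(num_tokens: int, pad_values: List[int]) -> List[int]:
    # Repeat the per-image hash ids until the padded span is covered, then trim.
    pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
    return pad_ids[:num_tokens]


if __name__ == "__main__":
    grid_thw = (1, 36, 54)              # e.g. a 504x756 image with patch_size=14
    n = num_image_tokens(grid_thw)      # 1 * 36 * 54 // 2 // 2 = 486
    span = expand_image_placeholder(n, pad_values=[1234567, 7654321])
    assert len(span) == n
    print(n)
```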