sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_vl.py (new file)
@@ -0,0 +1,724 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
+from functools import lru_cache, partial
+from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from vllm.config import CacheConfig, MultiModalConfig
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import QuickGELU
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsMultiModal
+
+from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2Model
+
+logger = init_logger(__name__)
+
+# === Vision Inputs === #
+
+
+class Qwen2VLImageInputs(TypedDict):
+    pixel_values: torch.Tensor
+    """Shape:
+    `(num_patches, num_channels * patch_size * patch_size)`
+    """
+
+    image_grid_thw: torch.Tensor
+    """Shape: `(num_images, 3)`
+
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+class Qwen2VLVideoInputs(TypedDict):
+    pixel_values_videos: torch.Tensor
+    """Shape:
+    `(num_patches,
+    num_channels * temporal_patch_size * patch_size * patch_size)`
+    """
+
+    video_grid_thw: torch.Tensor
+    """Shape: `(num_videos, 3)`
+
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+# === Vision Encoder === #
+
+
+class Qwen2VisionMLP(nn.Module):
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int = None,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            in_features, hidden_features, quant_config=quant_config
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            hidden_features, in_features, quant_config=quant_config
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class Qwen2VisionAttention(nn.Module):
+
+    def __init__(
+        self,
+        embed_dim: Optional[int] = None,
+        num_heads: Optional[int] = None,
+        projection_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        # Per attention head and per partition values.
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+
+        self.qkv = ColumnParallelLinear(
+            input_size=embed_dim,
+            output_size=3 * projection_size,
+            quant_config=quant_config,
+        )
+        self.proj = RowParallelLinear(
+            input_size=projection_size, output_size=embed_dim, quant_config=quant_config
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # [s, b, c] --> [s, b, head * 3 * head_dim]
+        x, _ = self.qkv(x)
+
+        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads_per_partition,
+            3 * self.hidden_size_per_attention_head,
+        )
+        x = x.view(*new_x_shape)
+
+        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        batch_size = q.shape[1]
+
+        q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = (seq_lens).max().item()
+        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+        output = torch.empty_like(q)
+        context_attention_fwd(
+            q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
+        )
+
+        context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
+        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+
+        output, _ = self.proj(context_layer)
+        return output
+
+
+class Qwen2VisionBlock(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+
+        self.attn = Qwen2VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            quant_config=quant_config,
+        )
+        self.mlp = Qwen2VisionMLP(
+            dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
+        )
+
+    def forward(
+        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+        )
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class Qwen2VisionPatchEmbed(nn.Module):
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_chans: int = 3,
+        embed_dim: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.embed_dim = embed_dim
+
+        kernel_size = [temporal_patch_size, patch_size, patch_size]
+        self.proj = nn.Conv3d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        L, C = x.shape
+        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
+        x = self.proj(x).view(L, self.embed_dim)
+        return x
+
+
+class Qwen2VisionPatchMerger(nn.Module):
+
+    def __init__(
+        self,
+        d_model: int,
+        context_dim: int,
+        norm_layer: Type[nn.Module] = None,
+        spatial_merge_size: int = 2,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.ln_q = norm_layer(context_dim)
+        self.mlp = nn.ModuleList(
+            [
+                ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias=True,
+                    quant_config=quant_config,
+                ),
+                nn.GELU(),
+                RowParallelLinear(
+                    self.hidden_size, d_model, bias=True, quant_config=quant_config
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.ln_q(x)
+        x = x.view(-1, self.hidden_size)
+
+        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
+        x_parallel, _ = mlp_fc1(x)
+        x_parallel = mlp_act(x_parallel)
+        out, _ = mlp_fc2(x_parallel)
+        return out
+
+
+class Qwen2VisionRotaryEmbedding(nn.Module):
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta
+                ** (
+                    torch.arange(
+                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
+                    )
+                    / self.dim
+                )
+            )
+            seq = torch.arange(
+                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+            )
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
+
+
+class Qwen2VisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        vision_config: Qwen2VLVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        patch_size: int = vision_config.patch_size
+        temporal_patch_size: int = vision_config.temporal_patch_size
+        spatial_merge_size: int = vision_config.spatial_merge_size
+        in_chans: int = vision_config.in_chans
+        hidden_size: int = vision_config.hidden_size
+        embed_dim: int = vision_config.embed_dim
+        depth: int = vision_config.depth
+        num_heads: int = vision_config.num_heads
+        mlp_ratio: float = vision_config.mlp_ratio
+
+        self.spatial_merge_size = spatial_merge_size
+
+        self.patch_embed = Qwen2VisionPatchEmbed(
+            patch_size=patch_size,
+            temporal_patch_size=temporal_patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        head_dim = embed_dim // num_heads
+        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = nn.ModuleList(
+            [
+                Qwen2VisionBlock(
+                    dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    norm_layer=norm_layer,
+                    quant_config=quant_config,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.merger = Qwen2VisionPatchMerger(
+            d_model=hidden_size,
+            context_dim=embed_dim,
+            norm_layer=norm_layer,
+            quant_config=quant_config,
+        )
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = (
+                hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            wpos_ids = (
+                wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+
+        # compute position embedding
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+        # transformers
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+        # adapter
+        x = self.merger(x)
+        return x
+
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
+        processor = cached_get_processor(self.config._name_or_path)
+        grid_t, grid_h, grid_w = image_grid_thw
+        num_image_tokens = (
+            grid_t
+            * grid_h
+            * grid_w
+            // processor.image_processor.merge_size
+            // processor.image_processor.merge_size
+        )
+        return num_image_tokens
+
+    # Use grid_t * grid_w * grid_h to pad tokens for each image
+    # and replaced padding by unique image hash
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_grid_thws = image_inputs.image_grid_thws
+        pad_values = image_inputs.pad_values
+
+        image_indices = [
+            idx
+            for idx, token in enumerate(input_ids)
+            if token == self.config.image_token_id
+        ]
+        image_inputs.image_offsets = []
+
+        input_ids_with_image = []
+        for image_cnt, _ in enumerate(image_grid_thws):
+            num_image_tokens = self.calculate_num_image_tokens(
+                image_grid_thws[image_cnt]
+            )
+            if image_cnt == 0:
+                non_image_tokens = input_ids[: image_indices[image_cnt]]
+            else:
+                non_image_tokens = input_ids[
+                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
+                ]
+            input_ids_with_image.extend(non_image_tokens)
+            image_inputs.image_offsets.append(len(input_ids_with_image))
+            pad_ids = pad_values * (
+                (num_image_tokens + len(pad_values)) // len(pad_values)
+            )
+            input_ids_with_image.extend(pad_ids[:num_image_tokens])
+        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
+
+        return input_ids_with_image
+
+    def __init__(
+        self,
+        config: Qwen2VLConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        self.visual = Qwen2VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            # NOTE: Qwen2-VL vision encoder does not support any
+            # quantization method now.
+            quant_config=None,
+        )
+
+        self.model = Qwen2Model(config, quant_config)
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
+        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+        return image_embeds
+
+    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
+        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
+        video_embeds = self.visual(
+            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
+        )
+        return video_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        """Run forward pass for Qwen2-VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,).
+                (Use input_metadata.mrope_positions to replace it)
+            pixel_values: Pixel values to be fed to a model.
+                `None` if no images are passed.
+            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
+                `None` if no images are passed.
+        """
+        image_inputs = None
+        if forward_batch.image_inputs is not None:
+            image_inputs = [
+                img for img in forward_batch.image_inputs if img is not None
+            ]
+
+        positions = forward_batch.mrope_positions
+        if (
+            forward_batch.forward_mode.is_decode()
+            or image_inputs is None
+            or len(image_inputs) == 0
+        ):
+            inputs_embeds = self.model.embed_tokens(input_ids)
+        else:
+            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
+            for i, image in enumerate(forward_batch.image_inputs):
+                if image == None:
+                    continue
+                start_idx = extend_start_loc_cpu[i]
+                prefix_len = prefix_lens_cpu[i]
+
+                pixel_values = torch.tensor(image.pixel_values, device="cuda")
+                image_grid_thws = torch.tensor(
+                    np.array(image.image_grid_thws), device="cuda"
+                )
+                image_offsets = image.image_offsets
+                image_input = Qwen2VLImageInputs(
+                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
+                )
+                image_embeds = self._process_image_input(image_input)
+
+                image_embeds_offset = 0
+                for idx, image_offset in enumerate(image_offsets):
+                    if image_offset < prefix_len:
+                        continue
+                    num_image_tokens = self.calculate_num_image_tokens(
+                        image_grid_thws[idx]
+                    )
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = (
+                        start_idx + (image_offset - prefix_len) + num_image_tokens
+                    )
+                    inputs_embeds[left_idx:right_idx] = image_embeds[
+                        image_embeds_offset : image_embeds_offset + num_image_tokens
+                    ]
+                    image_embeds_offset += num_image_tokens
+
+            input_ids = None
+
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=inputs_embeds,
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "visual" in name and "qkv.weight" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.embed_dim
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(
+                        3, visual_num_heads, head_size, visual_embed_dim
+                    )
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
+                elif "visual" in name and "qkv.bias" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.embed_dim
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1)
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen2VLForConditionalGeneration