sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -34,8 +33,15 @@ from einops import rearrange
 from transformers import AutoModel, Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
+from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLVisionConfig,
+)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration,
+)
 
-from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -47,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import
+from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -125,12 +132,15 @@ class Qwen2_5_VisionBlock(nn.Module):
         if attn_implementation == "sdpa":
             use_context_forward = False
             softmax_in_single_precision = False
+            flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
             use_context_forward = True
+            flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
             use_context_forward = False
+            flatten_batch = True
 
         self.attn = VisionAttention(
             embed_dim=dim,
@@ -139,7 +149,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
-            flatten_batch=
+            flatten_batch=flatten_batch,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -192,9 +202,10 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        target_dtype = self.proj.weight.dtype
         L, C = x.shape
         x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
         return x
 
 
@@ -246,35 +257,15 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
-        self.dim = dim
-        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
 
     def forward(self, seqlen: int) -> torch.Tensor:
-
-
+        seq = torch.arange(
+            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+        )
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
 
 
 class Qwen2_5_VisionTransformer(nn.Module):
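The hunk above drops the cached-frequency path (update_freqs_cache) and recomputes the rotary frequencies on every call. A minimal standalone sketch of the simplified module, mirroring the "+" lines; the class name here is illustrative, not the one in the wheel:

import torch
from torch import nn


class VisionRotaryEmbeddingSketch(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        # Inverse frequencies over the even dimensions, as in the diff.
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        # Outer product of positions and inverse frequencies -> (seqlen, dim // 2).
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(seq, self.inv_freq)


# Example: frequencies for a 4-position sequence with dim=8 have shape (4, 4).
print(VisionRotaryEmbeddingSketch(8)(4).shape)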
@@ -293,7 +284,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.
+        in_chans: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -335,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
 
     def get_window_index(self, grid_thw):
-        window_index: list = []
         cu_window_seqlens: list = [0]
         window_index_id = 0
         vit_merger_window_size = (
             self.window_size // self.spatial_merge_size // self.patch_size
         )
-
+        window_index: list = []
         for grid_t, grid_h, grid_w in grid_thw:
             llm_grid_h, llm_grid_w = (
                 grid_h // self.spatial_merge_size,
@@ -378,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
             window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
         window_index = torch.cat(window_index, dim=0)
-
         return window_index, cu_window_seqlens
 
     @property
@@ -391,29 +380,29 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-
-            hpos_ids = (
-
-
-
-
-                self.spatial_merge_size,
-            )
-            .permute(0, 2, 1, 3)
-            .flatten()
+
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
             )
-
-
-
-
-
-
-
-
-            .
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
             )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
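The rewritten rot_pos_emb above builds (h, w) position ids grouped per spatial merge window instead of using one chained reshape. A small worked example, assuming spatial_merge_size = 2 and a single 1×4×4 grid (values chosen only for illustration):

import torch

spatial_merge_size = 2
t, h, w = 1, 4, 4

hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
    h // spatial_merge_size, spatial_merge_size,
    w // spatial_merge_size, spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()

wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
    h // spatial_merge_size, spatial_merge_size,
    w // spatial_merge_size, spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()

pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
# The first four entries cover one 2x2 merge window rather than one raster row:
print(pos_ids[:4].tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]]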
@@ -437,7 +426,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=x.device,
-            dtype=
+            dtype=torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
 
@@ -455,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         position_embeddings = (emb.cos(), emb.sin())
 
         # compute cu_seqlens
-        cu_seqlens = torch.
-
-
+        cu_seqlens = torch.cat(
+            [
+                torch.tensor([0], device=grid_thw.device),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+            ]
+        )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
 
         # transformers
@@ -521,19 +513,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
-    def
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         # Get all special token IDs
         im_start_id: int = image_inputs.im_start_id
         im_end_id: int = image_inputs.im_end_id
@@ -543,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
         return pattern.pad_input_tokens(input_ids, image_inputs)
 
-    def
-        pixel_values = image_input
-        image_embeds = self.visual(pixel_values, grid_thw=image_input
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -555,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         )
         return video_embeds
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -577,85 +560,25 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions
 
-
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
+        if not (
             forward_batch.forward_mode.is_decode()
-            or
-            or len(image_inputs) == 0
+            or not forward_batch.contains_image_inputs()
         ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
 
-
-
-
-
-
-
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-
-                pixel_values = image.pixel_values.clone().detach().requires_grad_(False)
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-
-                    left_idx = start_idx + (image_offset - prefix_len)
-                    right_idx = left_idx + num_image_tokens
-
-                    tp_size = get_tensor_model_parallel_world_size()
-
-                    hidden_size = image_embeds.shape[-1]
-
-                    if hidden_size % tp_size != 0:
-                        padding_size = tp_size - (hidden_size % tp_size)
-                        image_embeds = F.pad(image_embeds, (0, padding_size))
-                        inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
-
-                    hidden_chunk_size = image_embeds.shape[-1] // tp_size
-                    rank = get_tensor_model_parallel_rank()
-                    start_dim = rank * hidden_chunk_size
-                    end_dim = (rank + 1) * hidden_chunk_size
-                    inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
-                        image_embeds[
-                            image_embeds_offset : image_embeds_offset
-                            + num_image_tokens,
-                            ...,
-                            start_dim:end_dim,
-                        ]
-                    )
-                    image_embeds_offset += num_image_tokens
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )
 
         hidden_states = self.model(
-            input_ids=
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
@@ -732,4 +655,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
-AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)
sglang/srt/models/qwen2_classification.py
ADDED
@@ -0,0 +1,75 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+
+class Qwen2ForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "Qwen2ForSequenceClassification is only used for embedding"
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+        return EmbeddingPoolerOutput(pooled_logits)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Filter out lm_head weights of Qwen2ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen2ForCausalLM.load_weights(self, filtered_weights)
+
+
+EntryClass = [
+    Qwen2ForSequenceClassification,
+]
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -44,10 +44,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
 
+expert_distribution_recorder = ExpertDistributionRecorder()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -170,6 +173,7 @@ class Qwen2MoeAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -201,7 +205,7 @@ class Qwen2MoeAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=
+            bias=qkv_bias,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
         )
@@ -257,6 +261,8 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # note: replace config.num_hidden_layers < 80 with True once its available in transformers 4.50.0
+        qkv_bias = getattr(config, "qkv_bias", config.num_hidden_layers < 80)
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -266,6 +272,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            qkv_bias=qkv_bias,
             prefix=add_prefix("self_attn", prefix),
         )
 
@@ -362,6 +369,7 @@ class Qwen2MoeModel(nn.Module):
         hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
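The qwen2_moe.py hunks above create a module-level expert_distribution_recorder and call set_current_layer(i) at the top of the decoder-layer loop, so later expert-routing statistics can be attributed to the layer that produced them. A hypothetical, minimal stand-in that only illustrates this hook pattern; the real ExpertDistributionRecorder (added as sglang/srt/managers/expert_distribution.py in this release) has its own API:

from collections import defaultdict


class RecorderSketch:
    # Hypothetical recorder: remembers the active layer and counts expert hits.
    def __init__(self) -> None:
        self.current_layer = None
        self.counts = defaultdict(int)

    def set_current_layer(self, layer_idx: int) -> None:
        self.current_layer = layer_idx

    def record(self, expert_idx: int) -> None:
        self.counts[(self.current_layer, expert_idx)] += 1


recorder = RecorderSketch()
for i in range(2):  # stands in for the loop over decoder layers in Qwen2MoeModel
    recorder.set_current_layer(i)
    recorder.record(expert_idx=7)
print(dict(recorder.counts))  # {(0, 7): 1, (1, 7): 1}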
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type, TypedDict
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -42,10 +41,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
 
     @property
     def dtype(self) -> torch.dtype:
-        return self.
+        return next(self.parameters()).dtype
 
     @property
     def device(self) -> torch.device:
@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
             hpos_ids = (
@@ -471,18 +472,18 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     # Use grid_t * grid_w * grid_h to pad tokens for each image
    # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int],
+    def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int =
-        im_end_id: int =
+        im_start_id: int = multi_modal_inputs.im_start_id
+        im_end_id: int = multi_modal_inputs.im_end_id
 
         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-        return pattern.pad_input_tokens(input_ids,
+        return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
 
-    def
-        pixel_values = image_input
-        image_embeds = self.visual(pixel_values, grid_thw=image_input
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         return video_embeds
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -514,67 +518,25 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions
 
-
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
+        if not (
             forward_batch.forward_mode.is_decode()
-            or
-            or len(image_inputs) == 0
+            or not forward_batch.contains_image_inputs()
         ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
 
-
-
-
-
-
-
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-                pixel_values = image.pixel_values.clone()
-
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-
-                    left_idx = start_idx + (image_offset - prefix_len + 1)
-                    right_idx = left_idx + num_image_tokens
-                    inputs_embeds[left_idx:right_idx] = image_embeds[
-                        image_embeds_offset : image_embeds_offset + num_image_tokens
-                    ]
-                    image_embeds_offset += num_image_tokens
-            input_ids = None
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )
 
         hidden_states = self.model(
-            input_ids=
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,