sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/granite.py
CHANGED
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
sglang/srt/models/llama.py
CHANGED
@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x
 
 
@@ -532,31 +540,6 @@ class LlamaForCausalLM(nn.Module):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
sglang/srt/models/llama4.py
CHANGED
@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -131,14 +131,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )
 
         out_aD = routed_out + shared_out
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)
 
         return out_aD
@@ -204,7 +209,7 @@ class Llama4Attention(nn.Module):
         super().__init__()
         self.layer_id = layer_id
         self.hidden_size = hidden_size
-        self.use_rope =
+        self.use_rope = (layer_id + 1) % 4 != 0
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
         attn_tp_rank = get_attention_tp_rank()
@@ -412,6 +417,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def _is_moe_layer(self, layer_id: int) -> bool:
@@ -441,8 +447,15 @@ class Llama4DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
@@ -466,7 +479,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
sglang/srt/models/qwen2.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
-                enable_tp=not
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -114,7 +114,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         num_heads: int,
         hidden_act="silu",
         norm_layer: Type[nn.Module] = None,
-        attn_implementation: Optional[str] =
+        attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -123,7 +123,12 @@ class Qwen2_5_VisionBlock(nn.Module):
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
-
+
+        if attn_implementation is None:
+            softmax_in_single_precision = False
+            qkv_backend = None
+            flatten_batch = True
+        elif attn_implementation == "sdpa":
             softmax_in_single_precision = False
             qkv_backend = "sdpa"
             flatten_batch = True
@@ -268,7 +273,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 num_heads=num_heads,
                 hidden_act=vision_config.hidden_act,
                 norm_layer=norm_layer,
-                attn_implementation="sdpa",
                 quant_config=quant_config,
                 prefix=add_prefix(f"blocks.{i}", prefix),
             )
sglang/srt/models/qwen2_audio.py
CHANGED
@@ -52,7 +52,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -106,15 +110,10 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
         self.language_model = Qwen2ForCausalLM(
             config.text_config, quant_config, prefix=add_prefix("model", prefix)
         )
+        self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        audio_token_id: int = getattr(
-            mm_inputs, "audio_token_id", mm_inputs.im_token_id
-        )
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
@@ -143,7 +142,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-
+            data_embedding_funcs={
+                Modality.AUDIO: self.get_audio_feature,
+            },
            positions=positions,
         )
 
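
The qwen2_audio change replaces the per-model audio plumbing with a `data_embedding_funcs` mapping keyed by modality, passed into the shared embedding routine. A minimal sketch of that dispatch pattern follows; the `Modality` enum values and the `embed_multimodal_items` helper are hypothetical stand-ins, not sglang's `general_mm_embed_routine`.

```python
# Minimal sketch of modality-keyed embedding dispatch (hypothetical, not sglang's API).
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple

import torch


class Modality(Enum):
    IMAGE = auto()
    AUDIO = auto()


def embed_multimodal_items(
    items: List[Tuple[Modality, object]],
    data_embedding_funcs: Dict[Modality, Callable[[list], torch.Tensor]],
) -> List[torch.Tensor]:
    # Group raw items by modality, then call the registered embedder once per modality.
    grouped: Dict[Modality, list] = {}
    for modality, data in items:
        grouped.setdefault(modality, []).append(data)
    return [data_embedding_funcs[m](batch) for m, batch in grouped.items()]


# Usage: register an audio embedder, mirroring
# data_embedding_funcs={Modality.AUDIO: self.get_audio_feature} in the diff.
fake_audio_embedder = lambda batch: torch.zeros(len(batch), 4)
embeddings = embed_multimodal_items(
    [(Modality.AUDIO, "clip_a"), (Modality.AUDIO, "clip_b")],
    {Modality.AUDIO: fake_audio_embedder},
)
print(embeddings[0].shape)  # torch.Size([2, 4])
```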
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -107,10 +108,14 @@ class Qwen2MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(
+    def forward(
+        self,
+        x,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
         return x
 
 
@@ -175,7 +180,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -193,6 +201,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -367,6 +376,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -392,7 +402,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
@@ -420,7 +435,7 @@ class Qwen2MoeModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen3.py
CHANGED
@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
     @torch.no_grad()
     def forward(
         self,
sglang/srt/models/qwen3_classification.py
ADDED
@@ -0,0 +1,78 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config  # Qwen3 uses Qwen2Config
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
+from sglang.srt.utils import add_prefix
+
+
+class Qwen3ForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+        # Use normalize=True for qwen3 embedding based on official implementation
+        # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
+        # Official code: output = F.normalize(output, p=2, dim=1)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "Qwen3ForSequenceClassification is only used for embedding"
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+        return EmbeddingPoolerOutput(pooled_logits)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Filter out lm_head weights of Qwen3ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen3ForCausalLM.load_weights(self, filtered_weights)
+
+
+EntryClass = [
+    Qwen3ForSequenceClassification,
+]
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -144,11 +144,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.top_k = config.num_experts_per_tok
 
     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -159,7 +162,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -167,7 +174,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -521,6 +528,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -546,7 +554,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
sglang/srt/models/registry.py
CHANGED
@@ -83,7 +83,7 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}
+            logger.warning(f"Ignore import error when loading {name}: {e}")
             continue
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
sglang/srt/models/step3_vl.py
CHANGED
@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
 
sglang/srt/models/torch_native_llama.py
CHANGED
@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)
sglang/srt/multimodal/processors/base_processor.py
CHANGED
@@ -22,13 +22,19 @@ class BaseMultiModalProcessorOutput:
     input_text: str
 
     # frames loaded from image, in given order
-    images: Optional[list[Union[Image.Image, dict]]] =
+    images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     # videos
-    videos: Optional[list[Union[torch.Tensor, dict]]] =
+    videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     # audios
-    audios: Optional[list[Union[np.ndarray, dict]]] =
+    audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     def organize_results(self) -> List[Tuple[Modality, Any]]:
         """
@@ -202,7 +208,7 @@ class BaseMultimodalProcessor(ABC):
 
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
-    ):
+    ) -> dict:
         """
        process multimodal data with transformers AutoProcessor
        """
@@ -211,10 +217,14 @@ class BaseMultimodalProcessor(ABC):
         if videos:
             kwargs["videos"] = videos
         if audios:
-
-
+            if self.arch in {
+                "Gemma3nForConditionalGeneration",
+                "Qwen2AudioForConditionalGeneration",
+            }:
                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                kwargs["audio"] = audios
+            else:
+                kwargs["audios"] = audios
 
         processor = self._processor
         if (
@@ -601,12 +611,6 @@ class BaseMultimodalProcessor(ABC):
         all_collected_items: list[MultimodalDataItem] = []
         input_ids = None
 
-        # Handle dict items (already processed)
-        for dict_item in dict_items:
-            all_collected_items.extend(
-                self.collect_mm_items_from_processor_output(dict_item)
-            )
-
         # Handle raw items (need processing)
         if raw_images or raw_audios or raw_videos:
             collected_items, input_ids, ret = self._process_and_collect_mm_items(
@@ -616,10 +620,16 @@ class BaseMultimodalProcessor(ABC):
                 videos=raw_videos,
                 **kwargs,
             )
-            all_collected_items
+            all_collected_items = collected_items
         else:
             ret = None
 
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
         # Fallback tokenization if no raw items were processed
         if input_ids is None:
             input_ids = self._processor.tokenizer(