sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -212,7 +213,7 @@ class DeepseekV2MLP(nn.Module):
         self,
         x,
         forward_batch=None,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
@@ -221,7 +222,7 @@ class DeepseekV2MLP(nn.Module):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
-            x, skip_all_reduce=
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
         )
         return x

@@ -448,7 +449,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
@@ -459,11 +460,11 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states,
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
             else:
                 return self.forward_normal(
-                    hidden_states,
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -471,7 +472,7 @@ class DeepseekV2MoE(nn.Module):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

@@ -500,20 +501,20 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states,
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -537,12 +538,14 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_cpu(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -593,7 +596,7 @@ class DeepseekV2MoE(nn.Module):
             None,  # a2_scale
             True,  # is_vnni
         )
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -1194,6 +1197,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output

+    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
+        """
+        Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.
+        """
+        return (
+            self.current_attention_backend == "trtllm_mla"
+            and forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
+        )
+
     def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
@@ -1273,7 +1286,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

         q_nope_out = q_nope_out.transpose(0, 1)
-
+
+        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

         return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

@@ -1286,8 +1301,20 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
         ):
+            extra_args = {}
+            if self._fuse_rope_for_trtllm_mla(forward_batch):
+                extra_args = {
+                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+                    "is_neox": self.rotary_emb.is_neox_style,
+                }
             attn_output = self.attn_mqa(
-                q_nope_out,
+                q_nope_out,
+                k_nope,
+                k_nope,
+                forward_batch,
+                q_rope=q_pe,
+                k_rope=k_pe,
+                **extra_args,
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
@@ -1771,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
         self.is_nextn = is_nextn
@@ -1842,6 +1868,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             allow_reduce_scatter=True,
         )

+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1850,27 +1878,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

     def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""

-
-
-
-
-
-
-        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
-            return False
-
-        if not _is_sm100_supported or not _is_flashinfer_available:
-            return False
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )

-        if
-            forward_batch.input_ids.shape[0] == 0
-            or forward_batch.input_ids.shape[0] > 128
-        ):
+        if batch_size > 128:
             return False

-        return
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)

     def forward(
         self,
@@ -1896,9 +1915,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-
+        should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (
+            and not (
+                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            )
             and not self.is_nextn
         )

@@ -1907,13 +1928,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch
         )
         hidden_states = self.mlp(
-            hidden_states, forward_batch,
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )

-        if
+        if should_allreduce_fusion:
             hidden_states._sglang_needs_allreduce_fusion = True

-        if not
+        if not should_allreduce_fusion:
             hidden_states, residual = self.layer_communicator.postprocess_layer(
                 hidden_states, residual, forward_batch
             )
@@ -1990,6 +2011,26 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output

+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+

 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
@@ -2008,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
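Note: the decoder layer now precomputes its all-reduce fusion decision in `_build_fuse_allreduce_lookup_table` at construction time, so the per-forward check in `_should_fuse_mlp_allreduce_with_next_layer` reduces to a bounds check plus a dictionary lookup keyed on batch size. The sketch below illustrates that idea in isolation; the names `static_ok`, `max_fused_batch_size`, and the two functions are illustrative stand-ins, not part of the sglang API.

```python
# Standalone sketch of a precomputed all-reduce fusion lookup table (illustrative only).
def build_fuse_allreduce_lookup_table(static_ok: bool, max_fused_batch_size: int = 128) -> dict:
    """Precompute, per batch size, whether the MLP all-reduce may be deferred
    and fused with the next layer's residual-add + RMSNorm kernel."""
    if not static_ok:  # static checks (TP > 1, flashinfer present, supported GPU, not last layer) failed
        return {}
    return {bs: bs > 0 for bs in range(max_fused_batch_size + 1)}


def should_fuse(table: dict, batch_size: int, max_fused_batch_size: int = 128) -> bool:
    # Hot-path decision collapses to one bounds check and one dict lookup.
    if batch_size > max_fused_batch_size:
        return False
    return table.get(batch_size, False)


table = build_fuse_allreduce_lookup_table(static_ok=True)
assert should_fuse(table, 64)
assert not should_fuse(table, 0) and not should_fuse(table, 512)
```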
sglang/srt/models/gemma2.py
CHANGED
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):

         return result

-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)

sglang/srt/models/gemma3n_mm.py
CHANGED
@@ -501,27 +501,26 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):

     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
-        if module_name
+        if module_name == "qkv_proj":
             return (
                 self.config.hidden_size,
-                self.config.head_dim
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
             )
-        elif module_name
+        elif module_name == "o_proj":
             return (
                 self.config.head_dim * self.config.num_attention_heads,
                 self.config.hidden_size,
             )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
         elif module_name == "gate_up_proj":
             assert len(set(self.config.intermediate_size)) == 1, (
                 "Currently SGLang requires uniform intermediate size for all layers. "
                 "Please file an issue if you need support for non-uniform intermediate sizes."
             )
-            return self.config.hidden_size, self.config.intermediate_size[0]
+            return self.config.hidden_size, self.config.intermediate_size[0] * 2
         elif module_name == "down_proj":
             assert len(set(self.config.intermediate_size)) == 1, (
                 "Currently SGLang requires uniform intermediate size for all layers. "
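Note: the corrected `get_hidden_dim` now reports the output width of the fused projections — `qkv_proj` packs the query heads together with two sets of key/value heads, and `gate_up_proj` concatenates the gate and up projections. A small illustration of that arithmetic with made-up config values (the real numbers come from the model's `PretrainedConfig`):

```python
# Illustrative values only; Gemma 3n's actual config supplies these fields.
hidden_size = 2048
head_dim = 256
num_attention_heads = 8
num_key_value_heads = 2
intermediate_size = 8192

# Fused qkv_proj: query heads plus two sets of key/value heads share one output tensor.
qkv_out = head_dim * (num_attention_heads + num_key_value_heads * 2)  # 256 * 12 = 3072
# Fused gate_up_proj: gate and up projections concatenated along the output dim.
gate_up_out = intermediate_size * 2  # 16384

print("qkv_proj:", (hidden_size, qkv_out))
print("gate_up_proj:", (hidden_size, gate_up_out))
```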
sglang/srt/models/glm4.py
CHANGED
@@ -218,6 +218,12 @@ class Glm4Model(nn.Module):

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
+
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
     @torch.no_grad()
     def forward(
         self,
sglang/srt/models/glm4_moe.py
CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -154,13 +155,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None,
+    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=
+        x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
         return x


@@ -529,7 +530,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

@@ -553,7 +554,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         if self.ep_size > 1:
             if (
                 self.tp_size > 1
-                and not
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -564,7 +565,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             final_hidden_states += shared_output
             if (
                 self.tp_size > 1
-                and not
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -575,13 +576,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states,
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -596,7 +597,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if self.ep_size > 1:
-            if self.tp_size > 1 and not
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
@@ -605,7 +606,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         else:
             if shared_output is not None:
                 final_hidden_states += shared_output
-            if self.tp_size > 1 and not
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,
@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
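Note: the same `should_allreduce_fusion` plumbing appears in both MoE blocks — when the flag is set, the block returns its partial sums and leaves the reduction to a later fused all-reduce + RMSNorm kernel (or to a reduce-scatter). A minimal sketch of that control flow, with a stub standing in for `tensor_model_parallel_all_reduce`:

```python
import torch


def all_reduce_stub(x: torch.Tensor) -> torch.Tensor:
    # Stub for tensor_model_parallel_all_reduce; a real run would sum across TP ranks.
    return x


def moe_block_epilogue(
    partial: torch.Tensor,
    tp_size: int,
    should_allreduce_fusion: bool = False,
    use_reduce_scatter: bool = False,
) -> torch.Tensor:
    """Reduce now, unless a downstream fused kernel or reduce-scatter will do it."""
    if tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
        return all_reduce_stub(partial)
    return partial  # reduction deferred


x = torch.randn(4, 8)
assert torch.equal(moe_block_epilogue(x, tp_size=2, should_allreduce_fusion=True), x)
```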
sglang/srt/models/glm4_moe_nextn.py
CHANGED
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
