sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py CHANGED
@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
  get_attention_tp_rank,
  get_attention_tp_size,
  get_local_attention_dp_size,
+ is_dp_attention_enabled,
  )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -212,7 +213,7 @@ class DeepseekV2MLP(nn.Module):
  self,
  x,
  forward_batch=None,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ):
  if (self.tp_size == 1) and x.shape[0] == 0:
@@ -221,7 +222,7 @@ class DeepseekV2MLP(nn.Module):
  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
  x, _ = self.down_proj(
- x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+ x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
  )
  return x

@@ -448,7 +449,7 @@ class DeepseekV2MoE(nn.Module):
  self,
  hidden_states: torch.Tensor,
  forward_batch: Optional[ForwardBatch] = None,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
  if not self._enable_deepep_moe:
@@ -459,11 +460,11 @@ class DeepseekV2MoE(nn.Module):
  and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
  ):
  return self.forward_normal_dual_stream(
- hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+ hidden_states, should_allreduce_fusion, use_reduce_scatter
  )
  else:
  return self.forward_normal(
- hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+ hidden_states, should_allreduce_fusion, use_reduce_scatter
  )
  else:
  return self.forward_deepep(hidden_states, forward_batch)
@@ -471,7 +472,7 @@ class DeepseekV2MoE(nn.Module):
  def forward_normal_dual_stream(
  self,
  hidden_states: torch.Tensor,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:

@@ -500,20 +501,20 @@ class DeepseekV2MoE(nn.Module):
  torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
  final_hidden_states = final_hidden_states_out
  sm.tag(final_hidden_states)
- if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+ if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

  def forward_normal(
  self,
  hidden_states: torch.Tensor,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
  if hasattr(self, "shared_experts") and use_intel_amx_backend(
  self.shared_experts.gate_up_proj
  ):
- return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+ return self.forward_cpu(hidden_states, should_allreduce_fusion)

  shared_output = self._forward_shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
@@ -537,12 +538,14 @@ class DeepseekV2MoE(nn.Module):
  torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
  final_hidden_states = final_hidden_states_out
  sm.tag(final_hidden_states)
- if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+ if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

  def forward_cpu(
- self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+ self,
+ hidden_states: torch.Tensor,
+ should_allreduce_fusion: bool = False,
  ) -> torch.Tensor:
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
@@ -593,7 +596,7 @@ class DeepseekV2MoE(nn.Module):
  None, # a2_scale
  True, # is_vnni
  )
- if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ if self.tp_size > 1 and not should_allreduce_fusion:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

@@ -1194,6 +1197,16 @@ class DeepseekV2AttentionMLA(nn.Module):
  output, _ = self.o_proj(attn_output)
  return output

+ def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
+ """
+ Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.
+ """
+ return (
+ self.current_attention_backend == "trtllm_mla"
+ and forward_batch.forward_mode.is_decode_or_idle()
+ and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
+ )
+
  def forward_absorb_prepare(
  self,
  positions: torch.Tensor,
@@ -1273,7 +1286,9 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

  q_nope_out = q_nope_out.transpose(0, 1)
- q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ if not self._fuse_rope_for_trtllm_mla(forward_batch):
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

  return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

@@ -1286,8 +1301,20 @@ class DeepseekV2AttentionMLA(nn.Module):
  or self.current_attention_backend == "cutlass_mla"
  or self.current_attention_backend == "trtllm_mla"
  ):
+ extra_args = {}
+ if self._fuse_rope_for_trtllm_mla(forward_batch):
+ extra_args = {
+ "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+ "is_neox": self.rotary_emb.is_neox_style,
+ }
  attn_output = self.attn_mqa(
- q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
+ q_nope_out,
+ k_nope,
+ k_nope,
+ forward_batch,
+ q_rope=q_pe,
+ k_rope=k_pe,
+ **extra_args,
  )
  else:
  q = torch.cat([q_nope_out, q_pe], dim=-1)
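Note on the two hunks above: when the TRTLLM MLA backend decodes with an fp8_e4m3 KV cache, RoPE is no longer applied in forward_absorb_prepare; instead cos_sin_cache and is_neox are forwarded so the backend can fuse rotary embedding with quantization. A minimal sketch of that conditional-kwargs dispatch, where attn, rotary_emb, and the surrounding function are illustrative placeholders rather than the real sglang APIs:

# Illustrative only: mirrors the dispatch pattern in the diff above.
def run_mla_attention(attn, rotary_emb, q_nope, k_nope, q_pe, k_pe,
                      positions, forward_batch, fuse_rope: bool):
    extra_args = {}
    if fuse_rope:
        # Let the backend apply RoPE (and fp8 quantization) inside its fused kernel.
        extra_args = {
            "cos_sin_cache": rotary_emb.cos_sin_cache,
            "is_neox": rotary_emb.is_neox_style,
        }
    else:
        # Otherwise apply RoPE eagerly before calling the backend.
        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
    return attn(q_nope, k_nope, k_nope, forward_batch,
                q_rope=q_pe, k_rope=k_pe, **extra_args)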
@@ -1771,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  rope_theta = getattr(config, "rope_theta", 10000)
  rope_scaling = getattr(config, "rope_scaling", None)
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
- self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
  self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
  self.layer_id = layer_id
  self.is_nextn = is_nextn
@@ -1842,6 +1868,8 @@ class DeepseekV2DecoderLayer(nn.Module):
  allow_reduce_scatter=True,
  )

+ self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
  def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
  return is_nextn or (
  self.config.n_routed_experts is not None
@@ -1850,27 +1878,18 @@ class DeepseekV2DecoderLayer(nn.Module):
  )

  def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
- """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
+ """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""

- if (
- self.layer_id == self.config.num_hidden_layers - 1
- or get_tensor_model_parallel_world_size() <= 1
- ):
- return False
-
- if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
- return False
-
- if not _is_sm100_supported or not _is_flashinfer_available:
- return False
+ batch_size = (
+ forward_batch.input_ids.shape[0]
+ if hasattr(forward_batch, "input_ids")
+ else 0
+ )

- if hasattr(forward_batch, "input_ids") and (
- forward_batch.input_ids.shape[0] == 0
- or forward_batch.input_ids.shape[0] > 128
- ):
+ if batch_size > 128:
  return False

- return True
+ return self._fuse_allreduce_lookup_table.get(batch_size, False)

  def forward(
  self,
@@ -1896,9 +1915,11 @@ class DeepseekV2DecoderLayer(nn.Module):
  hidden_states, residual, forward_batch
  )

- can_fuse_mlp_allreduce = (
+ should_allreduce_fusion = (
  self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
- and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+ and not (
+ is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+ )
  and not self.is_nextn
  )

@@ -1907,13 +1928,13 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch
  )
  hidden_states = self.mlp(
- hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+ hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
  )

- if can_fuse_mlp_allreduce:
+ if should_allreduce_fusion:
  hidden_states._sglang_needs_allreduce_fusion = True

- if not can_fuse_mlp_allreduce:
+ if not should_allreduce_fusion:
  hidden_states, residual = self.layer_communicator.postprocess_layer(
  hidden_states, residual, forward_batch
  )
@@ -1990,6 +2011,26 @@ class DeepseekV2DecoderLayer(nn.Module):
  )
  return output

+ def _build_fuse_allreduce_lookup_table(self):
+ static_conditions_met = (
+ self.layer_id != self.config.num_hidden_layers - 1
+ and get_tensor_model_parallel_world_size() > 1
+ and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+ and _is_sm100_supported
+ and _is_flashinfer_available
+ )
+
+ if not static_conditions_met:
+ return {}
+
+ lookup_table = {}
+ for batch_size in range(129): # 0 to 128
+ is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+ should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+ lookup_table[batch_size] = should_fuse
+
+ return lookup_table
+

  class DeepseekV2Model(nn.Module):
  fall_back_to_pt_during_load = False
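Note on _build_fuse_allreduce_lookup_table above: the static gating conditions (last layer, TP world size, the enable_flashinfer_allreduce_fusion flag, SM100 and FlashInfer availability) are now evaluated once at construction time, so the per-forward check reduces to a dictionary lookup keyed by batch size. A self-contained sketch of the same pattern, with made-up condition names rather than the actual sglang globals:

# Sketch of the precomputed-lookup pattern; names are stand-ins, not sglang APIs.
class FusionGate:
    MAX_FUSED_BATCH = 128

    def __init__(self, layer_id: int, num_layers: int, tp_size: int, fusion_flag: bool):
        static_ok = (
            layer_id != num_layers - 1  # the last layer has no "next layer" to fuse into
            and tp_size > 1             # fusion only matters when an allreduce happens
            and fusion_flag             # hardware / server-arg checks would go here too
        )
        # Evaluate static conditions once; at runtime only the batch size varies.
        self._table = (
            {bs: True for bs in range(1, self.MAX_FUSED_BATCH + 1)} if static_ok else {}
        )

    def should_fuse(self, batch_size: int) -> bool:
        # Batch size 0, or anything above 128, falls through to False.
        return self._table.get(batch_size, False)


gate = FusionGate(layer_id=0, num_layers=61, tp_size=8, fusion_flag=True)
assert gate.should_fuse(32) and not gate.should_fuse(512)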
@@ -2008,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
- enable_tp=not global_server_args_dict["enable_dp_attention"],
+ enable_tp=not is_dp_attention_enabled(),
  )
  self.alt_stream = torch.cuda.Stream() if _is_cuda else None
  self.layers = nn.ModuleList(
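Several hunks in this diff (here and in the GLM-4 MoE files below) replace direct reads of global_server_args_dict["enable_dp_attention"] with the is_dp_attention_enabled() helper imported from sglang.srt.layers.dp_attention. A minimal sketch of that accessor pattern; the body shown here is an assumption, not the actual sglang implementation:

# Hypothetical sketch: a module-level flag set once at startup and read through
# a helper, instead of every call site indexing a global server-args dict.
_ENABLE_DP_ATTENTION: bool = False

def initialize_dp_attention_flag(enable_dp_attention: bool) -> None:
    global _ENABLE_DP_ATTENTION
    _ENABLE_DP_ATTENTION = enable_dp_attention

def is_dp_attention_enabled() -> bool:
    return _ENABLE_DP_ATTENTION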
sglang/srt/models/gemma2.py CHANGED
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):

  return result

- def get_hidden_dim(self, module_name):
- # return input_dim, output_dim
- if module_name in ["q_proj", "qkv_proj"]:
- return (
- self.config.hidden_size,
- self.config.head_dim * self.config.num_attention_heads,
- )
- elif module_name in ["o_proj"]:
- return (
- self.config.head_dim * self.config.num_attention_heads,
- self.config.hidden_size,
- )
- elif module_name in ["kv_proj"]:
- return (
- self.config.hidden_size,
- self.config.head_dim * self.config.num_key_value_heads,
- )
- elif module_name == "gate_up_proj":
- return self.config.hidden_size, self.config.intermediate_size
- elif module_name == "down_proj":
- return self.config.intermediate_size, self.config.hidden_size
- else:
- raise NotImplementedError()
-
- def get_module_name(self, name):
- params_mapping = {
- "q_proj": "qkv_proj",
- "k_proj": "qkv_proj",
- "v_proj": "qkv_proj",
- "gate_proj": "gate_up_proj",
- "up_proj": "gate_up_proj",
- }
- return params_mapping.get(name, name)
-
  def get_attention_sliding_window_size(self):
  return get_attention_sliding_window_size(self.config)

sglang/srt/models/gemma3n_mm.py CHANGED
@@ -501,27 +501,26 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):

  def get_hidden_dim(self, module_name):
  # return input_dim, output_dim
- if module_name in ["q_proj", "qkv_proj"]:
+ if module_name == "qkv_proj":
  return (
  self.config.hidden_size,
- self.config.head_dim * self.config.num_attention_heads,
+ self.config.head_dim
+ * (
+ self.config.num_attention_heads
+ + self.config.num_key_value_heads * 2
+ ),
  )
- elif module_name in ["o_proj"]:
+ elif module_name == "o_proj":
  return (
  self.config.head_dim * self.config.num_attention_heads,
  self.config.hidden_size,
  )
- elif module_name in ["kv_proj"]:
- return (
- self.config.hidden_size,
- self.config.head_dim * self.config.num_key_value_heads,
- )
  elif module_name == "gate_up_proj":
  assert len(set(self.config.intermediate_size)) == 1, (
  "Currently SGLang requires uniform intermediate size for all layers. "
  "Please file an issue if you need support for non-uniform intermediate sizes."
  )
- return self.config.hidden_size, self.config.intermediate_size[0]
+ return self.config.hidden_size, self.config.intermediate_size[0] * 2
  elif module_name == "down_proj":
  assert len(set(self.config.intermediate_size)) == 1, (
  "Currently SGLang requires uniform intermediate size for all layers. "
sglang/srt/models/glm4.py CHANGED
@@ -218,6 +218,12 @@ class Glm4Model(nn.Module):

  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.embed_tokens
+
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
  @torch.no_grad()
  def forward(
  self,
sglang/srt/models/glm4_moe.py CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
  get_attention_tp_rank,
  get_attention_tp_size,
  get_local_attention_dp_size,
+ is_dp_attention_enabled,
  )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -154,13 +155,13 @@ class Glm4MoeMLP(nn.Module):
  )
  self.act_fn = SiluAndMul()

- def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+ def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
  if (self.tp_size == 1) and x.shape[0] == 0:
  return x

  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
- x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
+ x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
  return x


@@ -529,7 +530,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  def forward_normal_dual_stream(
  self,
  hidden_states: torch.Tensor,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:

@@ -553,7 +554,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  if self.ep_size > 1:
  if (
  self.tp_size > 1
- and not can_fuse_mlp_allreduce
+ and not should_allreduce_fusion
  and not use_reduce_scatter
  ):
  final_hidden_states = tensor_model_parallel_all_reduce(
@@ -564,7 +565,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  final_hidden_states += shared_output
  if (
  self.tp_size > 1
- and not can_fuse_mlp_allreduce
+ and not should_allreduce_fusion
  and not use_reduce_scatter
  ):
  final_hidden_states = tensor_model_parallel_all_reduce(
@@ -575,13 +576,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  def forward_normal(
  self,
  hidden_states: torch.Tensor,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
  if hasattr(self, "shared_experts") and use_intel_amx_backend(
  self.shared_experts.gate_up_proj
  ):
- return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+ return self.forward_cpu(hidden_states, should_allreduce_fusion)

  shared_output = self._forward_shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
@@ -596,7 +597,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  # fused in biased_grouped_topk so we can skip here
  final_hidden_states *= self.routed_scaling_factor
  if self.ep_size > 1:
- if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ if self.tp_size > 1 and not should_allreduce_fusion:
  final_hidden_states = tensor_model_parallel_all_reduce(
  final_hidden_states
  )
@@ -605,7 +606,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  else:
  if shared_output is not None:
  final_hidden_states += shared_output
- if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ if self.tp_size > 1 and not should_allreduce_fusion:
  final_hidden_states = tensor_model_parallel_all_reduce(
  final_hidden_states
  )
@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  )
  rms_norm_eps = config.rms_norm_eps
  attention_bias = config.attention_bias
- self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
  self.layer_id = layer_id
  self.self_attn = Glm4MoeAttention(
  hidden_size=self.hidden_size,
@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
- enable_tp=not global_server_args_dict["enable_dp_attention"],
+ enable_tp=not is_dp_attention_enabled(),
  )
  self.alt_stream = torch.cuda.Stream() if _is_cuda else None
  self.layers = nn.ModuleList(
sglang/srt/models/glm4_moe_nextn.py CHANGED
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig

  from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.layers.dp_attention import is_dp_attention_enabled
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
- enable_tp=not global_server_args_dict["enable_dp_attention"],
+ enable_tp=not is_dp_attention_enabled(),
  prefix=add_prefix("embed_tokens", prefix),
  )