sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -211,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -307,19 +312,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         self.experts = get_moe_impl_class()(
@@ -447,7 +448,8 @@ class DeepseekV2MoE(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -457,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
             else:
-                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
+                return self.forward_normal(
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
@@ -476,10 +483,14 @@ class DeepseekV2MoE(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             kwargs = {"hidden_states": hidden_states}
-            if self.topk is not None:
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+            # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+            # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+            if should_use_flashinfer_trtllm_moe():
+                kwargs["topk_output"] = (self.topk, router_logits)
             else:
-                kwargs["router_logits"] = router_logits
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
             final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
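Note on the dispatch above: should_use_flashinfer_trtllm_moe now lives in sglang.srt.layers.moe.utils, the MoE block always builds a TopK module, and the branch only decides what is handed to self.experts. A minimal standalone sketch of that dispatch pattern follows; the Fake* names and the always-False capability check are illustrative stand-ins, not sglang APIs.

def should_use_flashinfer_trtllm_moe() -> bool:
    # Stand-in for the capability check in sglang.srt.layers.moe.utils.
    return False

class FakeTopK:
    def __call__(self, hidden_states, router_logits):
        # Stand-in for a StandardTopKOutput-style result.
        return ("topk_ids", "topk_weights")

def build_expert_kwargs(hidden_states, router_logits, topk):
    kwargs = {"hidden_states": hidden_states}
    if should_use_flashinfer_trtllm_moe():
        # TRTLLM path: experts get the TopK module plus the raw router logits.
        kwargs["topk_output"] = (topk, router_logits)
    else:
        # CUTLASS/FusedMoE path: experts get the precomputed top-k output.
        kwargs["topk_output"] = topk(hidden_states, router_logits)
    return kwargs

print(build_expert_kwargs("h", "logits", FakeTopK()))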
@@ -489,26 +500,33 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
@@ -519,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
     def forward_cpu(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -575,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
             None,  # a2_scale
             True,  # is_vnni
         )
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -1176,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
+        """
+        Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.
+        """
+        return (
+            self.current_attention_backend == "trtllm_mla"
+            and forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
+        )
+
     def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
@@ -1255,7 +1285,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
 
@@ -1268,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
         ):
+            extra_args = {}
+            if self._fuse_rope_for_trtllm_mla(forward_batch):
+                extra_args = {
+                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+                    "is_neox": self.rotary_emb.is_neox_style,
+                }
             attn_output = self.attn_mqa(
-                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
+                q_nope_out,
+                k_nope,
+                k_nope,
+                forward_batch,
+                q_rope=q_pe,
+                k_rope=k_pe,
+                **extra_args,
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
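The _fuse_rope_for_trtllm_mla helper added above decides whether rotary embedding is applied eagerly or deferred into the MQA attention call together with quantization. A small sketch of the same gate, using placeholder dataclasses instead of sglang's ForwardBatch and attention backend:

from dataclasses import dataclass

import torch

@dataclass
class FakeAttnBackend:
    data_type: torch.dtype

@dataclass
class FakeForwardBatch:
    is_decode_or_idle: bool
    attn_backend: FakeAttnBackend

def fuse_rope_for_trtllm_mla(attention_backend: str, fb: FakeForwardBatch) -> bool:
    # Fuse rope+quantize in the kernel only for TRTLLM MLA decode with an fp8_e4m3 KV cache.
    return (
        attention_backend == "trtllm_mla"
        and fb.is_decode_or_idle
        and fb.attn_backend.data_type == torch.float8_e4m3fn
    )

fb = FakeForwardBatch(True, FakeAttnBackend(torch.float8_e4m3fn))
extra_args = (
    {"cos_sin_cache": "cos_sin_cache_tensor", "is_neox": True}
    if fuse_rope_for_trtllm_mla("trtllm_mla", fb)
    else {}
)
print(extra_args)  # rope is skipped upstream; the cos/sin cache is handed to the kernel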
@@ -1821,8 +1865,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1831,27 +1878,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
     def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
 
-        if (
-            self.layer_id == self.config.num_hidden_layers - 1
-            or get_tensor_model_parallel_world_size() <= 1
-        ):
-            return False
-
-        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
-            return False
-
-        if not _is_sm100_supported or not _is_flashinfer_available:
-            return False
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
 
-        if hasattr(forward_batch, "input_ids") and (
-            forward_batch.input_ids.shape[0] == 0
-            or forward_batch.input_ids.shape[0] > 128
-        ):
+        if batch_size > 128:
             return False
 
-        return True
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
 
     def forward(
         self,
@@ -1877,18 +1915,24 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        can_fuse_mlp_allreduce = (
+        should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
             and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
            and not self.is_nextn
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+        )
 
-        if can_fuse_mlp_allreduce:
+        if should_allreduce_fusion:
             hidden_states._sglang_needs_allreduce_fusion = True
 
-        if not can_fuse_mlp_allreduce:
+        if not should_allreduce_fusion:
             hidden_states, residual = self.layer_communicator.postprocess_layer(
                 hidden_states, residual, forward_batch
             )
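Read together with the DeepseekV2MLP change earlier in this diff, the two flags computed here collapse into a single skip_all_reduce decision inside the MLP's down projection. A condensed sketch of that flow, with a plain function standing in for the real layer:

def mlp_forward(x, should_allreduce_fusion=False, use_reduce_scatter=False):
    # down_proj skips its tensor-parallel all-reduce when a fused allreduce+rmsnorm
    # or a reduce-scatter will perform the reduction later instead.
    skip_all_reduce = should_allreduce_fusion or use_reduce_scatter
    return x, skip_all_reduce

for fusion, scatter in [(False, False), (True, False), (False, True)]:
    _, skipped = mlp_forward("hidden", fusion, scatter)
    print(fusion, scatter, "-> skip_all_reduce =", skipped)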
@@ -1965,6 +2009,26 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
 
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
 
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
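The lookup table built above front-loads every static check (not the last layer, TP size above one, the enable_flashinfer_allreduce_fusion server arg, SM100 hardware, FlashInfer availability), so the per-batch decision in _should_fuse_mlp_allreduce_with_next_layer reduces to a dictionary lookup. A standalone sketch of the same precomputation pattern, with module-level constants standing in for the server args and hardware probes:

ENABLE_FLASHINFER_ALLREDUCE_FUSION = True  # assumed server arg
TP_WORLD_SIZE = 8                          # assumed tensor-parallel size
IS_SM100 = True                            # assumed hardware probe
MAX_FUSED_BATCH_SIZE = 128

def build_fuse_allreduce_lookup_table(layer_id: int, num_hidden_layers: int) -> dict:
    static_conditions_met = (
        layer_id != num_hidden_layers - 1
        and TP_WORLD_SIZE > 1
        and ENABLE_FLASHINFER_ALLREDUCE_FUSION
        and IS_SM100
    )
    if not static_conditions_met:
        return {}
    # Batch size 0 never fuses; 1..128 fuse once the static checks pass.
    return {bs: bs > 0 for bs in range(MAX_FUSED_BATCH_SIZE + 1)}

table = build_fuse_allreduce_lookup_table(layer_id=3, num_hidden_layers=61)

def should_fuse(batch_size: int) -> bool:
    if batch_size > MAX_FUSED_BATCH_SIZE:
        return False
    return table.get(batch_size, False)

print(should_fuse(64), should_fuse(0), should_fuse(256))  # True False False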
@@ -2060,6 +2124,8 @@ class DeepseekV2Model(nn.Module):
 
 
 class DeepseekV2ForCausalLM(nn.Module):
+    # for quark model load
+    packed_modules_mapping = {}
 
     def __init__(
         self,
@@ -2068,6 +2134,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # for quark model load
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        self.fuse_qkv_a_proj = (
+            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+        )
+        if self.fuse_qkv_a_proj:
+            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+                "q_a_proj",
+                "kv_a_proj_with_mqa",
+            ]
+
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
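The packed_modules_mapping entry registered above tells a quantized (quark) checkpoint loader that q_a_proj and kv_a_proj_with_mqa are fused along the output dimension into fused_qkv_a_proj_with_mqa. A small illustrative sketch of how such a mapping is typically consumed when remapping checkpoint parameter names; the helper below is a sketch, not sglang's actual loader:

packed_modules_mapping = {
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
}

def remap_checkpoint_name(name: str):
    """Map an unfused checkpoint weight name to (fused module name, shard index)."""
    for fused, parts in packed_modules_mapping.items():
        for shard_id, part in enumerate(parts):
            if part in name:
                return name.replace(part, fused), shard_id
    return name, 0

print(remap_checkpoint_name("model.layers.0.self_attn.q_a_proj.weight"))
# -> ('model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight', 0)
print(remap_checkpoint_name("model.layers.0.self_attn.kv_a_proj_with_mqa.weight"))
# -> ('model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight', 1)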