sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -29,10 +29,14 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -56,13 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -96,7 +96,6 @@ from sglang.srt.two_batch_overlap import (
 )
 from sglang.srt.utils import (
     BumpAllocator,
-    DeepEPMode,
     LazyValue,
     add_prefix,
     bind_or_assign,
@@ -209,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+        )
         return x
 
 
@@ -305,19 +312,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         self.experts = get_moe_impl_class()(
@@ -333,15 +336,14 @@ class DeepseekV2MoE(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["enable_deepep_moe"]
+                dict(deepep_mode=global_server_args_dict["deepep_mode"])
+                if global_server_args_dict["moe_a2a_backend"].is_deepep()
                 else {}
             ),
             # Additional args for FusedMoE
            **(
                dict(
                    enable_flashinfer_cutlass_moe=True,
-                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                )
                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                else {}
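
This hunk and the ones that follow replace the boolean server args enable_deepep_moe / enable_ep_moe with a single moe_a2a_backend value queried through is_deepep(). A minimal sketch of such a backend selector is shown below; the class name, enum values, and location are illustrative assumptions, not the actual sglang implementation.

from enum import Enum

class MoeA2ABackend(Enum):
    # Hypothetical values for illustration only.
    NONE = "none"
    DEEPEP = "deepep"

    def is_deepep(self) -> bool:
        # A single predicate replaces the old enable_deepep_moe boolean flag.
        return self is MoeA2ABackend.DEEPEP

# Usage mirroring the diff:
#   if global_server_args_dict["moe_a2a_backend"].is_deepep():
#       ...
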
@@ -374,7 +376,7 @@ class DeepseekV2MoE(nn.Module):
                 prefix=add_prefix("shared_experts", prefix),
                 **(
                     dict(tp_rank=0, tp_size=1)
-                    if global_server_args_dict["enable_deepep_moe"]
+                    if global_server_args_dict["moe_a2a_backend"].is_deepep()
                     else {}
                 ),
             )
@@ -404,9 +406,9 @@ class DeepseekV2MoE(nn.Module):
 
         self.top_k = config.num_experts_per_tok
 
-        if global_server_args_dict["enable_deepep_moe"]:
+        if global_server_args_dict["moe_a2a_backend"].is_deepep():
             # TODO: we will support tp < ep in the future
-            self.ep_size = get_tensor_model_parallel_world_size()
+            self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
                 + global_server_args_dict["ep_num_redundant_experts"]
@@ -428,12 +430,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                deepep_mode=global_server_args_dict["deepep_mode"],
                 async_finish=True,
                 return_recv_hook=True,
             )
 
-        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
 
     def get_moe_weights(self):
         return [
@@ -447,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -456,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
                 )
             else:
-                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
+                return self.forward_normal(
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
@@ -475,21 +483,32 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states += shared_output
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -500,17 +519,25 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
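
Both forward paths now materialize the shared-expert addition into a fresh tensor inside a use_symmetric_memory(...) region and tag it before the tensor-parallel all-reduce. The sketch below only restates the calling pattern visible in the hunks; the allocator internals (what the context manager registers and what tag records) are assumptions, not the actual pynccl_allocator implementation.

import torch
from sglang.srt.distributed import parallel_state, tensor_model_parallel_all_reduce
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)

def add_shared_then_all_reduce(routed: torch.Tensor, shared: torch.Tensor) -> torch.Tensor:
    # Allocate the sum inside the symmetric-memory window so the later NCCL
    # all-reduce can (presumably) operate on a window-registered buffer.
    with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
        out = torch.empty_like(routed)
        torch.add(routed, shared, out=out)
        sm.tag(out)  # assumed semantics: mark the tensor as symmetric-memory backed
    return tensor_model_parallel_all_reduce(out)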
 
@@ -1812,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
@@ -1874,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             and not self.is_nextn
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+        )
 
         if can_fuse_mlp_allreduce:
             hidden_states._sglang_needs_allreduce_fusion = True
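
The new comment notes that, for data-parallel attention with padded batches, a reduce-scatter can replace the all-reduce because each rank only needs its own slice of the summed output afterwards. A toy comparison of the two collectives in plain torch.distributed (not sglang's communicator) is sketched below.

import torch
import torch.distributed as dist

def combine_padded_dp(hidden: torch.Tensor, use_reduce_scatter: bool) -> torch.Tensor:
    # Toy illustration: with padding, every rank holds the full (padded) batch.
    world = dist.get_world_size()
    if use_reduce_scatter:
        # Each rank keeps only its shard of the summed tensor, which moves roughly
        # half the data of an all-reduce when the full result is not needed everywhere.
        shard = hidden.new_empty(hidden.shape[0] // world, *hidden.shape[1:])
        dist.reduce_scatter_tensor(shard, hidden)
        return shard
    # All-reduce: every rank ends up with the complete summed tensor.
    dist.all_reduce(hidden)
    return hidden
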
@@ -2051,6 +2085,8 @@ class DeepseekV2Model(nn.Module):
 
 
 class DeepseekV2ForCausalLM(nn.Module):
+    # for quark model load
+    packed_modules_mapping = {}
 
     def __init__(
         self,
@@ -2059,6 +2095,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # for quark model load
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        self.fuse_qkv_a_proj = (
+            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+        )
+        if self.fuse_qkv_a_proj:
+            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+                "q_a_proj",
+                "kv_a_proj_with_mqa",
+            ]
+
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
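
packed_modules_mapping tells a quantization-aware weight loader that two checkpoint tensors (q_a_proj and kv_a_proj_with_mqa) land in one fused parameter. The loop below is a generic sketch of how such a mapping is typically consumed when remapping checkpoint names; it is illustrative only, not sglang's actual loader.

# Hypothetical illustration of consuming packed_modules_mapping during weight load.
packed_modules_mapping = {
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
}

def remap(checkpoint_name: str):
    # Map a checkpoint weight name to (fused_param_name, shard_index) if it is packed.
    for fused_name, shards in packed_modules_mapping.items():
        for shard_index, shard_name in enumerate(shards):
            if shard_name in checkpoint_name:
                return checkpoint_name.replace(shard_name, fused_name), shard_index
    return checkpoint_name, None

print(remap("model.layers.0.self_attn.q_a_proj.weight"))
# -> ('model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight', 0)
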
@@ -2104,11 +2152,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             or self.config.n_shared_experts != 1
         ):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif (
-            global_server_args_dict["enable_deepep_moe"]
-            or global_server_args_dict["enable_ep_moe"]
-        ):
-            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True