sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
113
113
  quant_config: Optional[QuantizationConfig] = None,
114
114
  prefix: str = "",
115
115
  num_dummy_heads: int = 0,
116
+ rms_norm_eps: float = 1e-6,
116
117
  ) -> None:
117
118
  super().__init__()
118
119
  if norm_layer is None:
119
120
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
120
- self.norm1 = RMSNorm(dim, eps=1e-6)
121
- self.norm2 = RMSNorm(dim, eps=1e-6)
121
+ self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
122
+ self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
122
123
 
123
124
  if attn_implementation is None:
124
125
  softmax_in_single_precision = False
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
517
518
  self.logits_processor = LogitsProcessor(config)
518
519
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
519
520
 
521
+ # For EAGLE3 support
522
+ self.capture_aux_hidden_states = False
523
+
520
524
  def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
521
525
  pattern = MultiModalityDataPaddingPatternMultimodalTokens()
522
526
  return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
587
591
  positions=positions,
588
592
  )
589
593
 
594
+ aux_hidden_states = None
595
+ if self.capture_aux_hidden_states:
596
+ hidden_states, aux_hidden_states = hidden_states
597
+
590
598
  if not get_embedding:
591
599
  return self.logits_processor(
592
- input_ids, hidden_states, self.lm_head, forward_batch
600
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
593
601
  )
594
602
  else:
595
603
  return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
643
651
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
644
652
  weight_loader(param, loaded_weight)
645
653
 
654
+ def get_embed_and_head(self):
655
+ return self.model.embed_tokens.weight, self.lm_head.weight
656
+
657
+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
658
+ self.capture_aux_hidden_states = True
659
+ self.model.capture_aux_hidden_states = True
660
+ if layer_ids is None:
661
+ num_layers = self.config.num_hidden_layers
662
+ self.model.layers_to_capture = [
663
+ 2,
664
+ num_layers // 2,
665
+ num_layers - 3,
666
+ ] # Specific layers for EAGLE3 support
667
+ else:
668
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
669
+
646
670
 
647
671
  EntryClass = [Qwen2_5_VLForConditionalGeneration]
@@ -62,13 +62,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
62
62
  VocabParallelEmbedding,
63
63
  )
64
64
  from sglang.srt.managers.schedule_batch import global_server_args_dict
65
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
65
66
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
66
67
  from sglang.srt.model_loader.weight_utils import default_weight_loader
67
68
  from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
68
- from sglang.srt.utils import add_prefix, make_layers
69
+ from sglang.srt.utils import add_prefix, is_cuda, make_layers
69
70
 
70
71
  logger = logging.getLogger(__name__)
71
72
 
73
+ _is_cuda = is_cuda()
74
+
72
75
 
73
76
  class Qwen2MoeMLP(nn.Module):
74
77
  def __init__(
@@ -122,11 +125,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
122
125
  layer_id: int,
123
126
  config: PretrainedConfig,
124
127
  quant_config: Optional[QuantizationConfig] = None,
128
+ alt_stream: Optional[torch.cuda.Stream] = None,
125
129
  prefix: str = "",
126
130
  ):
127
131
  super().__init__()
128
132
  self.tp_size = get_tensor_model_parallel_world_size()
129
133
  self.layer_id = layer_id
134
+ self.alt_stream = alt_stream
130
135
  if self.tp_size > config.num_experts:
131
136
  raise ValueError(
132
137
  f"Tensor parallel size {self.tp_size} is greater than "
@@ -138,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
138
143
  renormalize=config.norm_topk_prob,
139
144
  )
140
145
 
141
- self.experts = get_moe_impl_class()(
146
+ self.experts = get_moe_impl_class(quant_config)(
142
147
  layer_id=self.layer_id,
143
148
  top_k=config.num_experts_per_tok,
144
149
  num_experts=config.num_experts,
@@ -168,14 +173,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
168
173
  self.shared_expert = None
169
174
  self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
170
175
 
171
- def forward(
172
- self,
173
- hidden_states: torch.Tensor,
174
- forward_batch: Optional[ForwardBatch] = None,
175
- use_reduce_scatter: bool = False,
176
- ) -> torch.Tensor:
177
- num_tokens, hidden_dim = hidden_states.shape
178
- hidden_states = hidden_states.view(-1, hidden_dim)
176
+ def _forward_shared_experts(self, hidden_states: torch.Tensor):
179
177
  shared_output = None
180
178
  if self.shared_expert is not None:
181
179
  shared_output = self.shared_expert(hidden_states)
@@ -183,11 +181,52 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
183
181
  shared_output = (
184
182
  F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
185
183
  )
184
+ return shared_output
186
185
 
186
+ def _forward_router_experts(self, hidden_states: torch.Tensor):
187
187
  # router_logits: (num_tokens, n_experts)
188
188
  router_logits, _ = self.gate(hidden_states)
189
189
  topk_output = self.topk(hidden_states, router_logits)
190
- final_hidden_states = self.experts(hidden_states, topk_output)
190
+ return self.experts(hidden_states, topk_output)
191
+
192
+ def forward_normal_dual_stream(
193
+ self,
194
+ hidden_states: torch.Tensor,
195
+ ) -> torch.Tensor:
196
+ current_stream = torch.cuda.current_stream()
197
+ self.alt_stream.wait_stream(current_stream)
198
+ shared_output = self._forward_shared_experts(hidden_states.clone())
199
+
200
+ with torch.cuda.stream(self.alt_stream):
201
+ router_output = self._forward_router_experts(hidden_states)
202
+
203
+ current_stream.wait_stream(self.alt_stream)
204
+
205
+ return router_output, shared_output
206
+
207
+ def forward(
208
+ self,
209
+ hidden_states: torch.Tensor,
210
+ forward_batch: Optional[ForwardBatch] = None,
211
+ use_reduce_scatter: bool = False,
212
+ ) -> torch.Tensor:
213
+ num_tokens, hidden_dim = hidden_states.shape
214
+ hidden_states = hidden_states.view(-1, hidden_dim)
215
+
216
+ DUAL_STREAM_TOKEN_THRESHOLD = 1024
217
+ if (
218
+ self.alt_stream is not None
219
+ and hidden_states.shape[0] > 0
220
+ and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
221
+ and get_is_capture_mode()
222
+ ):
223
+ final_hidden_states, shared_output = self.forward_normal_dual_stream(
224
+ hidden_states
225
+ )
226
+ else:
227
+ shared_output = self._forward_shared_experts(hidden_states)
228
+ final_hidden_states = self._forward_router_experts(hidden_states)
229
+
191
230
  if shared_output is not None:
192
231
  final_hidden_states = final_hidden_states + shared_output
193
232
  if self.tp_size > 1 and not use_reduce_scatter:
@@ -346,6 +385,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
346
385
  layer_id=layer_id,
347
386
  config=config,
348
387
  quant_config=quant_config,
388
+ alt_stream=alt_stream,
349
389
  prefix=add_prefix("mlp", prefix),
350
390
  )
351
391
  else:
@@ -528,8 +568,12 @@ class Qwen2MoeForCausalLM(nn.Module):
528
568
  self.pp_group = get_pp_group()
529
569
  self.config = config
530
570
  self.quant_config = quant_config
571
+ alt_stream = torch.cuda.Stream() if _is_cuda else None
531
572
  self.model = Qwen2MoeModel(
532
- config, quant_config, prefix=add_prefix("model", prefix)
573
+ config,
574
+ quant_config,
575
+ prefix=add_prefix("model", prefix),
576
+ alt_stream=alt_stream,
533
577
  )
534
578
  self.lm_head = ParallelLMHead(
535
579
  config.vocab_size,
@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
98
98
  use_grouped_topk=False,
99
99
  )
100
100
 
101
- self.experts = get_moe_impl_class()(
101
+ self.experts = get_moe_impl_class(quant_config)(
102
102
  num_experts=config.num_experts
103
103
  + global_server_args_dict["ep_num_redundant_experts"],
104
104
  top_k=config.num_experts_per_tok,