sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = RMSNorm(dim, eps=1e-6)
-        self.norm2 = RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
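
The two new methods above are the hooks an EAGLE3 speculative-decoding worker needs from a target model: auxiliary hidden states from a few selected layers, plus access to the embedding and LM-head weights so the draft model scores tokens in the same vocabulary space. A minimal sketch of how such hooks are typically wired together; attach_eagle3_hooks, draft_model, and set_embed_and_head are hypothetical names for illustration, not sglang's actual worker API:

# Hypothetical wiring sketch; attach_eagle3_hooks / set_embed_and_head are illustrative names.
from typing import List, Optional

def attach_eagle3_hooks(target_model, draft_model, layer_ids: Optional[List[int]] = None):
    # Make the target model return auxiliary hidden states from a few layers
    # (defaults to layer 2, the middle layer, and the third-from-last, as above).
    target_model.set_eagle3_layers_to_capture(layer_ids)
    # Reuse the target's embedding and output-head weights for the draft model.
    embed_weight, head_weight = target_model.get_embed_and_head()
    draft_model.set_embed_and_head(embed_weight, head_weight)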
sglang/srt/models/qwen2_moe.py
@@ -62,13 +62,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -105,11 +108,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -119,11 +125,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -135,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
@@ -165,14 +173,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -180,11 +181,52 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 shared_output = (
                     F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
                 )
+        return shared_output
 
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
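
The dual-stream path above overlaps the shared-expert MLP with the routed-expert computation during CUDA graph capture: the side stream first waits on the main stream so the input is ready, the two branches then run concurrently, and the main stream waits on the side stream before the results are combined. The same pattern in isolation, as a sketch with placeholder functions (requires a CUDA device):

# Sketch of the dual-stream overlap pattern (placeholder functions fn_main/fn_side).
import torch

def run_overlapped(fn_main, fn_side, x, alt_stream):
    main_stream = torch.cuda.current_stream()
    # The side stream must not read x before the main stream has produced it.
    alt_stream.wait_stream(main_stream)
    # .clone() gives the main-stream branch its own buffer, mirroring the
    # hidden_states.clone() above, so the two branches do not alias memory.
    main_out = fn_main(x.clone())
    with torch.cuda.stream(alt_stream):
        side_out = fn_side(x)
    # Later work on the main stream must see the side-stream results.
    main_stream.wait_stream(alt_stream)
    return main_out, side_out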
@@ -343,6 +385,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -525,8 +568,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
sglang/srt/models/qwen3.py
@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix, is_cuda
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
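
The new "scale" branch lets load_weights handle checkpoints that store KV-cache quantization scales under names that differ from the runtime parameter names: maybe_remap_kv_scale_name either returns a name that exists in params_dict or None, in which case the entry is skipped. An illustrative stand-in for that behaviour (the real rules live in sglang.srt.model_loader.weight_utils; the suffix pairs below are assumptions made for the sketch):

# Illustrative stand-in only; not the actual maybe_remap_kv_scale_name implementation.
from typing import Dict, Optional

def remap_kv_scale_name_sketch(name: str, params_dict: Dict) -> Optional[str]:
    remaps = {
        ".k_proj.k_scale": ".attn.k_scale",
        ".v_proj.v_scale": ".attn.v_scale",
    }
    for src, dst in remaps.items():
        if name.endswith(src):
            candidate = name[: -len(src)] + dst
            # Return the remapped name only if the model actually has that
            # parameter; otherwise the caller skips this checkpoint entry.
            return candidate if candidate in params_dict else None
    # Scale names we do not recognize are skipped as well.
    return name if name in params_dict else None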
sglang/srt/models/qwen3_moe.py
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 
@@ -88,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +545,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
        )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )
 
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
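
When should_fuse_mlp_allreduce_with_next_layer returns true, the layer above skips postprocess_layer (and the tensor-parallel all-reduce inside it) and instead tags the activation with _sglang_needs_allreduce_fusion, leaving the reduction to be fused with the next layer's normalization work. The control flow reduced to its essentials, as a sketch with placeholder names rather than the actual LayerCommunicator API (requires an initialized torch.distributed process group):

# Sketch of deferring a tensor-parallel all-reduce to the next layer
# (placeholder names; not the actual sglang LayerCommunicator API).
import torch
import torch.distributed as dist

def finish_layer(hidden_states: torch.Tensor, defer_allreduce: bool) -> torch.Tensor:
    if defer_allreduce:
        # Leave per-rank partial sums in place and mark the tensor so the
        # next layer knows a reduction is still owed.
        hidden_states._needs_allreduce_fusion = True
        return hidden_states
    dist.all_reduce(hidden_states)
    return hidden_states

def begin_next_layer(hidden_states: torch.Tensor) -> torch.Tensor:
    # A fused kernel would combine this reduction with the residual add and
    # normalization; this fallback simply performs it eagerly.
    if getattr(hidden_states, "_needs_allreduce_fusion", False):
        dist.all_reduce(hidden_states)
        hidden_states._needs_allreduce_fusion = False
    return hidden_states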