sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py

@@ -40,7 +40,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionRotaryEmbedding,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = RMSNorm(dim, eps=1e-6)
-        self.norm2 = RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -264,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
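Note: the `mlp_hidden_size` change pads the vision MLP width up to the next multiple of 8, presumably so the dimension splits cleanly across tensor-parallel ranks and aligned kernels. `((x + 7) // 8) * 8` is the usual integer round-up idiom; a quick self-check in Python:

def round_up(x: int, multiple: int = 8) -> int:
    # Ceiling-round x to the nearest multiple: ((x + m - 1) // m) * m.
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(3420) == 3424  # a width not divisible by 8 gets padded up
assert round_up(3424) == 3424  # already-aligned widths are unchanged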
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
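These hooks give Qwen2.5-VL the same EAGLE3 speculative-decoding surface as sglang's text models: the target model exposes its embedding and LM-head weights and records hidden states from a few layers (by default layer 2, the middle layer, and the third-from-last) for the draft model to condition on. A minimal sketch of the capture pattern, with an illustrative decoder loop rather than sglang's actual Qwen2Model:

import torch
from torch import nn

class TinyDecoder(nn.Module):
    # Illustrative stand-in for a decoder with aux-hidden-state capture.
    def __init__(self, num_layers: int = 8, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.capture_aux_hidden_states = False
        self.layers_to_capture = [2, num_layers // 2, num_layers - 3]

    def forward(self, x: torch.Tensor):
        aux = []
        for i, layer in enumerate(self.layers):
            if self.capture_aux_hidden_states and i in self.layers_to_capture:
                aux.append(x)  # snapshot the hidden state entering this layer
            x = layer(x)
        # When capture is on, callers unpack (hidden_states, aux_hidden_states),
        # mirroring the tuple-unpacking added to forward() above.
        return (x, aux) if self.capture_aux_hidden_states else x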
sglang/srt/models/qwen2_audio.py

@@ -39,7 +39,6 @@ from transformers.models.qwen2_audio.modeling_qwen2_audio import (
     Qwen2AudioMultiModalProjector,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/qwen2_moe.py

@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -62,13 +65,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -79,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -87,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -95,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -105,11 +117,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -119,11 +134,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -135,10 +152,11 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -160,19 +178,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -180,11 +211,85 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             shared_output = (
                 F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
             )
+        return shared_output
+
+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
 
+        return final_hidden_states
+
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -343,6 +448,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -525,8 +631,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
        )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
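The qwen2_moe changes thread an alternate CUDA stream through the decoder so that, for small batches during CUDA-graph capture, the shared expert (current stream) and the routed experts (alt stream) execute concurrently; the two branches are independent until their outputs are summed. The synchronization pattern is plain PyTorch stream APIs; a self-contained sketch, assuming a CUDA device is available:

import torch

def run_dual_stream(fn_main, fn_alt, x, alt_stream):
    # Overlap two independent branches on separate CUDA streams.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)    # alt stream must see x fully produced
    main_out = fn_main(x.clone())      # clone: both branches read x concurrently
    with torch.cuda.stream(alt_stream):
        alt_out = fn_alt(x)
    current.wait_stream(alt_stream)    # join before anything consumes alt_out
    return main_out, alt_out

if torch.cuda.is_available():
    alt = torch.cuda.Stream()
    x = torch.randn(64, 128, device="cuda")
    shared_out, routed_out = run_dual_stream(torch.tanh, torch.relu, x, alt)
    out = routed_out + shared_out  # mirrors final_hidden_states + shared_output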
sglang/srt/models/qwen2_vl.py

@@ -33,7 +33,6 @@ from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -50,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/qwen3.py

@@ -1,6 +1,5 @@
 # Adapted from qwen2.py
 import logging
-from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -24,15 +23,25 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 
 class Qwen3Attention(nn.Module):
@@ -232,9 +241,18 @@ class Qwen3DecoderLayer(nn.Module):
 
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
@@ -458,7 +476,10 @@ class Qwen3ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
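The `maybe_remap_kv_scale_name` hook lets Qwen3 load checkpoints that ship KV-cache quantization scales: parameter names containing `scale` are remapped onto the model's attention parameters, and scales with no destination are skipped rather than crashing the loader. A rough sketch of the idea (not the actual sglang helper; the mappings below are assumptions for illustration):

def remap_kv_scale_name(name: str, params_dict: dict):
    # Map checkpoint-style scale names onto model parameter names;
    # return None when the model has no slot for this scale.
    for src, dst in ((".k_proj.k_scale", ".attn.k_scale"),
                     (".v_proj.v_scale", ".attn.v_scale")):
        if name.endswith(src):
            candidate = name.replace(src, dst)
            return candidate if candidate in params_dict else None
    return name if name in params_dict else None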
sglang/srt/models/qwen3_moe.py

@@ -42,13 +42,16 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -57,10 +60,21 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 
@@ -88,7 +102,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -119,11 +133,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +154,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +164,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -335,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -393,7 +420,21 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -401,7 +442,13 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
+        )
         output, _ = self.o_proj(attn_output)
         return output
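The two attention hunks work as a pair: when `enable_fused_set_kv_buffer(...)` holds (and the layer's rotary embedding is not an `MRotaryEmbedding`), the rope call also writes K/V into the KV cache, so the subsequent `RadixAttention` call is issued with `save_kv_cache=False` to avoid storing the same entries twice. A toy illustration of the single-writer invariant (illustrative shapes and names, not sglang's API):

import torch

def rope_and_maybe_store(k, v, cache, slot, fused: bool):
    k = torch.roll(k, 1, dims=-1)  # stand-in for the rotary transform
    if fused:
        cache[slot] = torch.cat([k, v], dim=-1)  # fused kernel writes the cache
    return k

def attn_store(k, v, cache, slot, save_kv_cache: bool):
    if save_kv_cache:
        cache[slot] = torch.cat([k, v], dim=-1)  # attention writes the cache

cache = torch.zeros(4, 8)
k, v = torch.randn(4), torch.randn(4)
for fused in (False, True):
    k_rot = rope_and_maybe_store(k, v, cache, 0, fused)
    attn_store(k_rot, v, cache, 0, save_kv_cache=not fused)  # exactly one write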
 
@@ -500,6 +547,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +573,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
        )
 
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
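The `should_allreduce_fusion` path defers the MLP's tensor-parallel all-reduce so it can be fused with the next layer's residual add and RMSNorm: instead of reducing immediately, the layer tags the tensor and lets a downstream kernel pick it up. The tag is an ordinary Python attribute, which `torch.Tensor` objects accept; a minimal illustration (the consumer shown is a sketch, not sglang's actual fused kernel):

import torch

x = torch.randn(4, 8)
x._sglang_needs_allreduce_fusion = True  # plain attribute, lives on this object

def maybe_fused_norm(t):
    if getattr(t, "_sglang_needs_allreduce_fusion", False):
        # Here a fused allreduce+add+rmsnorm kernel would run once,
        # replacing a separate all-reduce followed by layernorm.
        pass
    return t

maybe_fused_norm(x)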