sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py

@@ -25,30 +25,6 @@ if TYPE_CHECKING:
 def quantize(w, dtype, dev, **opt):
     if dtype == "bf16":
         return w.to(torch.bfloat16), InFlexData()
-    elif dtype == "fp8":
-        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
-        return (
-            wq,
-            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
-            MicroscalingCtx(),
-        )
-    else:
-        assert dtype == "mx4", f"{dtype=}"
-        swizzle_mx_scale = opt["swizzle_mx_scale"]
-        swizzle_axis = 2 if swizzle_mx_scale else None
-        w = w.to(torch.bfloat16)
-        w, mx_scales, weight_scale_shape = downcast_to_mxfp(
-            w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
-        )
-        return (
-            w,
-            InFlexData(),
-            MicroscalingCtx(
-                weight_scale=mx_scales,
-                swizzle_mx=swizzle_mx_scale,
-                actual_weight_scale_shape=weight_scale_shape,
-            ),
-        )
 
 
 def triton_kernel_moe_forward(
@@ -119,14 +95,14 @@ def triton_kernel_fused_experts(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
 
-    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
-    assert per_channel_quant == False, "per_channel_quant is not supported"
-    assert expert_map == None, "expert_map is not supported"
-    assert w1_scale == None, "w1_scale is not supported"
-    assert w2_scale == None, "w2_scale is not supported"
-    assert a1_scale == None, "a1_scale is not supported"
-    assert a2_scale == None, "a2_scale is not supported"
-    assert block_shape == None, "block_shape is not supported"
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"
 
     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -143,7 +119,7 @@ def triton_kernel_fused_experts(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
 
     # feature check
-    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"
 
     M, K = hidden_states.shape
     E, _, N = w1.shape
@@ -264,14 +240,14 @@ def triton_kernel_fused_experts_with_bias(
     gemm1_alpha: Optional[float] = None,
     gemm1_clamp_limit: Optional[float] = None,
 ) -> torch.Tensor:
-    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
-    assert per_channel_quant == False, "per_channel_quant is not supported"
-    assert expert_map == None, "expert_map is not supported"
-    assert w1_scale == None, "w1_scale is not supported"
-    assert w2_scale == None, "w2_scale is not supported"
-    assert a1_scale == None, "a1_scale is not supported"
-    assert a2_scale == None, "a2_scale is not supported"
-    assert block_shape == None, "block_shape is not supported"
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"
 
     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -290,7 +266,7 @@ def triton_kernel_fused_experts_with_bias(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
 
     # feature check
-    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"
 
     E, _, _ = w1.shape
 
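A note on the assert rewrites above: they follow PEP 8, which prescribes identity
checks for singletons such as None, True, and False. The == operator dispatches to
a type's __eq__, which a class (or an elementwise type like torch.Tensor) may
override, while "is" compares object identity and cannot be fooled. A minimal
illustration in plain Python, independent of this codebase:

    class AlwaysEqual:
        def __eq__(self, other):
            # a pathological __eq__ that claims equality with everything
            return True

    x = AlwaysEqual()
    print(x == None)  # True  -- __eq__ lies
    print(x is None)  # False -- identity cannot be overridden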
sglang/srt/layers/moe/moe_runner/deep_gemm.py (new file)

@@ -0,0 +1,304 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+from sglang.srt.utils import dispose_tensor
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
+@dataclass
+class DeepGemmRunnerInput(RunnerInput):
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+    use_masked_gemm: bool
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmRunnerOutput(RunnerOutput):
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmMoeQuantInfo(MoeQuantInfo):
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    use_fp8: bool
+    w13_scale: Optional[torch.Tensor] = None
+    w2_scale: Optional[torch.Tensor] = None
+    block_shape: Optional[List[int]] = None
+
+
+class DeepGemmRunnerCore(MoeRunnerCore):
+    def __init__(self, config: MoeRunnerConfig):
+        super().__init__(config)
+        assert self.config.activation == "silu"
+
+    def run(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> DeepGemmRunnerOutput:
+
+        if runner_input.use_masked_gemm:
+            hidden_states = self._run_masked_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        else:
+            hidden_states = self._run_contiguous_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        return DeepGemmRunnerOutput(hidden_states=hidden_states)
+
+    def _run_masked_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers import deep_gemm_wrapper
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            silu_and_mul_masked_post_quant_fwd,
+        )
+
+        hidden_states = runner_input.hidden_states
+        hidden_states_scale = runner_input.hidden_states_scale
+        masked_m = runner_input.masked_m
+        expected_m = runner_input.expected_m
+
+        w13_weight = quant_info.w13_weight
+        w2_weight = quant_info.w2_weight
+        w13_scale = quant_info.w13_scale
+        w2_scale = quant_info.w2_scale
+
+        hidden_states_device = running_state["hidden_states_device"]
+
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = hidden_states_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+        # GroupGemm-0
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+        else:
+            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                hidden_states_scale
+            )
+
+        num_groups, m, k = hidden_states.shape
+        n = w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (hidden_states, hidden_states_scale),
+            (w13_weight, w13_scale),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+        dispose_tensor(hidden_states)
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float8_e4m3fn,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = w2_weight.shape[1]
+
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                down_input_scale
+            )
+
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (down_input, down_input_scale),
+            (w2_weight, w2_scale),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        del down_input
+
+        return down_output
+
+    def _run_contiguous_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+        pass
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@register_pre_permute("standard", "deep_gemm")
+def pre_permute_standard_to_deep_gemm(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
+
+    hidden_states, topk_output = dispatch_output
+    topk_weights, topk_ids, _ = topk_output
+
+    hidden_states_shape = hidden_states.shape
+    hidden_states_dtype = hidden_states.dtype
+    hidden_states_device = hidden_states.device
+    hidden_states_ref = hidden_states
+
+    topk_weights, topk_ids = topk_weights, topk_ids
+
+    # PreReorder
+    masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
+        moe_ep_deepgemm_preprocess(
+            topk_ids,
+            runner_config.num_local_experts,
+            hidden_states,
+            runner_config.top_k,
+            quant_info.block_shape,
+        )
+    )
+
+    dispose_tensor(hidden_states_ref)
+
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+    running_state["hidden_states_shape"] = hidden_states_shape
+    running_state["hidden_states_dtype"] = hidden_states_dtype
+    running_state["hidden_states_device"] = hidden_states_device
+    running_state["src2dst"] = src2dst
+
+    return DeepGemmRunnerInput(
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        masked_m=masked_m,
+        expected_m=expected_m,
+        use_masked_gemm=True,
+    )
+
+
+@register_post_permute("deep_gemm", "standard")
+def post_permute_deep_gemm_to_standard(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    hidden_states_shape = running_state["hidden_states_shape"]
+    hidden_states_dtype = running_state["hidden_states_dtype"]
+    hidden_states_device = running_state["hidden_states_device"]
+    src2dst = running_state["src2dst"]
+    topk_ids = running_state["topk_ids"]
+    topk_weights = running_state["topk_weights"]
+
+    output = torch.empty(
+        hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+    )
+    post_reorder_triton_kernel[(hidden_states_shape[0],)](
+        runner_output.hidden_states,
+        output,
+        src2dst,
+        topk_ids,
+        topk_weights,
+        runner_config.top_k,
+        hidden_states_shape[1],
+        BLOCK_SIZE=512,
+    )
+
+    dispose_tensor(runner_output.hidden_states)
+
+    if runner_config.routed_scaling_factor is not None:
+        output *= runner_config.routed_scaling_factor
+
+    return StandardCombineInput(
+        hidden_states=output,
+    )
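The helper _cast_to_e8m0_with_rounding_up in the new file above extracts the
biased fp32 exponent (bits 23-30) and rounds it up whenever any mantissa bits
remain, so the resulting power-of-two e8m0 scale never under-represents the
value it replaces. A scalar sketch of the same bit logic, assuming positive
normal inputs; e8m0_round_up is a hypothetical name used only for illustration:

    import struct

    def e8m0_round_up(x: float) -> int:
        # one-float mirror of the tensor version above
        bits = struct.unpack("<I", struct.pack("<f", x))[0]
        exp, mant = bits >> 23, bits & 0x7FFFFF
        if mant > 0 and exp != 0xFE and not (exp == 0 and mant <= 0x400000):
            exp += 1  # round the exponent up
        return exp

    assert e8m0_round_up(4.0) == 129  # 4.0 = 2**2 is exact, kept as-is
    assert e8m0_round_up(3.0) == 129  # 3.0 rounds up to 2**2 = 4.0
    assert 2.0 ** (e8m0_round_up(3.0) - 127) >= 3.0  # the scale never rounds down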
sglang/srt/layers/moe/moe_runner/runner.py

@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
     MoeRunnerConfig,
     PermuteMethodPool,
 )
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
@@ -30,6 +31,8 @@ class MoeRunner:
 
         if runner_backend.is_triton():
             self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_deep_gemm():
+            self.runner_core = DeepGemmRunnerCore(config)
         else:
             raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
 
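The DeepGEMM core is wired in through the register_pre_permute and
register_post_permute hooks seen in deep_gemm.py above. A speculative sketch of
how a registry keyed by (dispatch format, runner backend) can work; the dict
name and decorator body here are illustrative assumptions, not the actual
PermuteMethodPool implementation:

    from typing import Callable, Dict, Tuple

    # hypothetical registry keyed by (dispatch format, runner backend)
    _PRE_PERMUTE: Dict[Tuple[str, str], Callable] = {}

    def register_pre_permute(fmt: str, backend: str):
        def deco(fn: Callable) -> Callable:
            _PRE_PERMUTE[(fmt, backend)] = fn  # record the converter
            return fn
        return deco

    @register_pre_permute("standard", "deep_gemm")
    def standard_to_deep_gemm(dispatch_output):
        ...  # build a DeepGemmRunnerInput from a StandardDispatchOutput

    # the runner can then look up the converter for its configured backend
    convert = _PRE_PERMUTE[("standard", "deep_gemm")]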
sglang/srt/layers/moe/moe_runner/triton.py

@@ -51,7 +51,9 @@ elif _is_hip:
 
 
 if _is_cuda or _is_hip:
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+    from sgl_kernel import (  # noqa: F401
+        moe_align_block_size as sgl_moe_align_block_size,
+    )
 
 
 @dataclass
sglang/srt/layers/moe/rocm_moe_utils.py

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from enum import IntEnum
-from functools import cache
 from typing import Optional
 
 import torch
sglang/srt/layers/moe/router.py

@@ -11,7 +11,7 @@ _is_hip = is_hip()
 
 
 @triton.jit
-def fused_moe_router_kernel(
+def fused_moe_router_cudacore_kernel(
     input_ptr,  # input (bs, hidden_dim)
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
     # assert not moe_renormalize, "moe weight renormalization not implemented"
 
 
-def fused_moe_router_impl(
+def fused_moe_router_cudacore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
         ),
     }
 
-    fused_moe_router_kernel[(bs,)](
+    fused_moe_router_cudacore_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
@@ -157,7 +157,7 @@ def fused_moe_router_impl(
 
 
 @triton.jit
-def fused_moe_router_large_bs_kernel(
+def fused_moe_router_tensorcore_kernel(
     a_ptr,  # input (bs, hidden_dim)
     b_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
     topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     K: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     stride_am: tl.constexpr,
     stride_bn: tl.constexpr,
+    dp_attn_workaround_flag: tl.constexpr,
 ):
 
     # 1. get block id
@@ -217,6 +220,20 @@
     exped = tl.exp(2 * logits_scaled)
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(
+            correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
+            mask=expert_mask.T,
+            other=0.0,
+        )
+        logits_softcapped = logits_softcapped + bias
+
+    if dp_attn_workaround_flag:
+        logits_softcapped = tl.where(
+            logits_softcapped != logits_softcapped, -1e9, logits_softcapped
+        )
+
     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
     cond_top1 = arange_block_size_n < num_experts
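Two details in the hunk above are worth unpacking. The softcap lines are the
algebraic form of tanh, tanh(x) = (e^(2x) - 1) / (e^(2x) + 1), so the kernel
computes moe_softcapping * tanh(logits / moe_softcapping); and the dp-attention
workaround relies on NaN being the only float unequal to itself. A small sketch
of both in plain PyTorch, outside Triton:

    import torch

    x = torch.tensor([0.5, 30.0, float("nan")])
    softcap = 30.0

    # softcap * tanh(x / softcap), written with the same exp form as the kernel
    exped = torch.exp(2 * (x / softcap))
    softcapped = (exped - 1) / (exped + 1) * softcap
    assert torch.allclose(softcapped[:2], softcap * torch.tanh(x[:2] / softcap))

    # NaN != NaN, so `v != v` selects exactly the NaN lanes (padded tokens here),
    # which the workaround pins to a large negative logit
    masked = torch.where(softcapped != softcapped, torch.tensor(-1e9), softcapped)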
@@ -266,7 +283,7 @@
     )
 
 
-def fused_moe_router_large_bs_impl(
+def fused_moe_router_tensorcore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
     BLOCK_SIZE_M: int,
     BLOCK_SIZE_N: int,
     BLOCK_SIZE_K: int,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -285,10 +303,17 @@
 
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
     grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
 
-    fused_moe_router_large_bs_kernel[grid](
+    # TODO(ch-wan): temporary workaround for dp attention. We should support masked
+    # router to skip padded tokens.
+    from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+
+    dp_attn_workaround_flag = is_dp_attention_enabled()
+
+    fused_moe_router_tensorcore_kernel[grid](
         a_ptr=x,
         b_ptr=router_weight,
         topk_weights_ptr=topk_weights,
@@ -299,11 +324,14 @@
         moe_softcapping=moe_softcapping,
         moe_renormalize=False,
         K=hidden_dim,
+        correction_bias_ptr=correction_bias,
+        is_correction_bias=is_correction_bias,
         BLOCK_SIZE_M=BLOCK_SIZE_M,
         BLOCK_SIZE_N=BLOCK_SIZE_N,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         stride_am=hidden_dim,
         stride_bn=hidden_dim,
+        dp_attn_workaround_flag=dp_attn_workaround_flag,
     )
 
     return topk_weights, topk_ids
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
     topk,
     renormalize,
     correction_bias: Optional[torch.Tensor] = None,
+    enable_deterministic_inference: bool = False,
 ):
     assert not renormalize
     assert (
@@ -324,16 +353,22 @@
     )
     bs, hidden_dim = hidden_states.shape
     num_experts = gating_output.shape[0]
+
     BLOCK_SIZE_M = 32
-    BLOCK_SIZE_N = 16
-    BLOCK_SIZE_K = 256
+
+    BLOCK_SIZE_N = max(num_experts, 16)
+    BLOCK_SIZE_K = (
+        256 if num_experts < 256 else 64
+    )  # if experts are large, need to use smaller k block or shared memory OOM
+
     if (
-        bs >= 512
-        and topk <= 2
-        and num_experts <= BLOCK_SIZE_N
+        (bs >= 512 or num_experts > 8)
         and hidden_dim % BLOCK_SIZE_K == 0
+        # we keep using single kernel to avoid non-deterministic behavior
+        and not enable_deterministic_inference
     ):
-        return fused_moe_router_large_bs_impl(
+        # if large batch size or large expert, use kernel that uses tensorcore in matmul
+        return fused_moe_router_tensorcore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
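The reworked shim now sizes blocks from num_experts instead of fixed constants,
and routes to the tensor-core kernel for large batches or large expert counts
unless deterministic inference is requested. A quick check of the block-size
arithmetic, mirroring the rules above; pick_block_sizes is just an illustrative
wrapper, not a function in the codebase:

    def pick_block_sizes(num_experts: int):
        BLOCK_SIZE_M = 32
        BLOCK_SIZE_N = max(num_experts, 16)
        # large expert counts use a smaller K block to avoid shared-memory OOM
        BLOCK_SIZE_K = 256 if num_experts < 256 else 64
        return BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K

    assert pick_block_sizes(8) == (32, 16, 256)
    assert pick_block_sizes(128) == (32, 128, 256)
    assert pick_block_sizes(256) == (32, 256, 64)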
@@ -341,9 +376,11 @@
             BLOCK_SIZE_M=BLOCK_SIZE_M,
             BLOCK_SIZE_N=BLOCK_SIZE_N,
             BLOCK_SIZE_K=BLOCK_SIZE_K,
+            correction_bias=correction_bias,
         )
     else:
-        return fused_moe_router_impl(
+        # if smaller, use kernel that does not use tensorcore in matmul
+        return fused_moe_router_cudacore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
@@ -380,11 +417,10 @@ class FusedMoeRouter:
             renormalize=False,
         )
 
-    def forward_vllm(
+    def forward_torch(
         self,
         x: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # g, _ = self.router_linear.forward(x)
         g = x.float() @ self.router_linear.weight.T.float()
 
         g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
sglang/srt/layers/moe/token_dispatcher/__init__.py

@@ -16,8 +16,14 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
     DeepEPNormalCombineInput,
     DeepEPNormalOutput,
 )
+from sglang.srt.layers.moe.token_dispatcher.mooncake import (
+    MooncakeCombineInput,
+    MooncakeDispatchOutput,
+    MooncakeEPDispatcher,
+)
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardCombineInput,
+    StandardDispatcher,
     StandardDispatchOutput,
 )
 
@@ -30,6 +36,10 @@ __all__ = [
     "DispatchOutput",
     "DispatchOutputFormat",
     "DispatchOutputChecker",
+    "MooncakeCombineInput",
+    "MooncakeDispatchOutput",
+    "MooncakeEPDispatcher",
+    "StandardDispatcher",
     "StandardDispatchOutput",
     "StandardCombineInput",
     "DeepEPConfig",
sglang/srt/layers/moe/token_dispatcher/base.py

@@ -73,7 +73,7 @@ class DispatchOutputFormat(Enum):
 class DispatchOutput(Protocol):
     """Protocol for dispatch outputs in different formats."""
 
-    # TODO: add hidden_states to the protocol
+    hidden_states: torch.Tensor
 
     @property
     def format(self) -> DispatchOutputFormat: ...
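The final hunk turns a TODO into a structural requirement: every dispatch output
must now expose a hidden_states tensor. A minimal sketch of what that buys under
typing.Protocol; FakeDispatchOutput is a hypothetical implementer used only for
illustration:

    from typing import Protocol

    import torch

    class DispatchOutput(Protocol):
        hidden_states: torch.Tensor  # now part of the structural type

    class FakeDispatchOutput:  # hypothetical implementer
        def __init__(self, h: torch.Tensor) -> None:
            self.hidden_states = h

    def combine(out: DispatchOutput) -> torch.Tensor:
        # any object carrying a hidden_states tensor satisfies the protocol
        return out.hidden_states

    combine(FakeDispatchOutput(torch.zeros(2, 4)))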