sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,386 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from typing import NamedTuple, Optional, Tuple
6
+
7
+ from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
8
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
9
+ from sglang.srt.layers.dp_attention import get_is_extend_in_batch
10
+ from sglang.srt.layers.moe.token_dispatcher.base import (
11
+ BaseDispatcher,
12
+ CombineInput,
13
+ CombineInputFormat,
14
+ DispatchOutput,
15
+ DispatchOutputFormat,
16
+ )
17
+ from sglang.srt.layers.moe.topk import TopKOutput
18
+ from sglang.srt.layers.moe.utils import DeepEPMode
19
+ from sglang.srt.utils import get_int_env_var
20
+
21
+ try:
22
+ from mooncake.mooncake_ep_buffer import Buffer
23
+
24
+ use_mooncake_ep = True
25
+ except ImportError:
26
+ use_mooncake_ep = False
27
+
28
+ from enum import Enum, auto
29
+
30
+ import torch
31
+ import torch.distributed as dist
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class MooncakeDispatchOutput(NamedTuple):
    """Mooncake EP dispatch output.

    Bundles everything produced by a low-latency Mooncake EP dispatch so it
    can be handed to the expert computation and later to combine.
    """

    # Dispatched hidden states. Upstream (dispatch_b) unpacks a tuple of
    # (states, scale) into this field plus hidden_states_scale.
    hidden_states: torch.Tensor
    # Per-tensor/per-block scales when the dispatch was quantized; None when
    # the dispatch returned a bare tensor.
    hidden_states_scale: Optional[torch.Tensor]
    # Selected expert ids (cast to int64 before dispatch).
    topk_ids: torch.Tensor
    # Routing weights matching topk_ids.
    topk_weights: torch.Tensor
    # Per-local-expert valid token counts from the masked layout.
    masked_m: torch.Tensor
    # Estimated per-expert token count used to size masked GEMMs.
    expected_m: int

    @property
    def format(self) -> DispatchOutputFormat:
        # Mooncake reuses the DeepEP low-latency output layout.
        return DispatchOutputFormat.DEEPEP_LL


# Structural (Protocol) check that the class satisfies DispatchOutput.
assert isinstance(MooncakeDispatchOutput, DispatchOutput)
52
+
53
+
54
class MooncakeCombineInput(NamedTuple):
    """Mooncake EP combine input.

    Placeholder: carries no fields yet and only advertises its combine
    format, mirroring MooncakeDispatchOutput.
    """

    # NOTE: the original had a redundant `pass` here; the docstring and the
    # property below already form a valid class body, so it was removed.

    @property
    def format(self) -> CombineInputFormat:
        # Mooncake reuses the DeepEP low-latency combine layout.
        return CombineInputFormat.DEEPEP_LL


# Structural (Protocol) check that the class satisfies CombineInput.
assert isinstance(MooncakeCombineInput, CombineInput)
65
+
66
+
67
class EPBuffer:
    """Process-wide lazy singleton wrapper around the Mooncake EP ``Buffer``.

    The buffer is created on the first :meth:`get_ep_buffer` call and cached;
    every later call returns the cached instance.
    """

    # Cached Buffer instance (None until first get_ep_buffer call).
    _buffer = None
    # Sizing parameters recorded at creation time.
    _hidden_size: Optional[int] = None
    _num_max_dispatch_tokens_per_rank: Optional[int] = None
    _num_experts: Optional[int] = None

    @classmethod
    def get_ep_buffer(
        cls,
        group: dist.ProcessGroup,
        hidden_size: int,
        param_bytes: int,
        deepep_mode: DeepEPMode,
        num_max_dispatch_tokens_per_rank: int = -1,
        num_experts: int = -1,
    ):
        """Return the singleton EP buffer, creating it on first use.

        Args:
            group: EP process group the buffer communicates over.
            hidden_size: model hidden dimension used to size the buffer.
            param_bytes: bytes per element; currently unused in this method
                (kept for interface parity — NOTE(review): confirm intent).
            deepep_mode: which DeepEP-style modes are enabled; only
                low-latency is supported here.
            num_max_dispatch_tokens_per_rank: required (!= -1) when
                low-latency mode is enabled.
            num_experts: required (!= -1, divisible by group size) when
                low-latency mode is enabled.

        Raises:
            NotImplementedError: if normal (non-low-latency) mode is enabled.
        """
        # Fast path. NOTE(review): a later call with *different* sizes silently
        # returns the old buffer — parameters are not re-validated here.
        if cls._buffer is not None:
            return cls._buffer

        cls._hidden_size = hidden_size
        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        cls._num_experts = num_experts

        num_ep_buffer_bytes = 0
        if deepep_mode.enable_normal():
            raise NotImplementedError(
                "Normal mode is not supported for Mooncake EP yet."
            )
        if deepep_mode.enable_low_latency():
            # -1 is the "not provided" sentinel; low-latency sizing needs both.
            assert num_max_dispatch_tokens_per_rank != -1
            assert num_experts != -1 and num_experts % group.size() == 0
            num_ep_buffer_bytes = Buffer.get_ep_buffer_size_hint(
                num_max_dispatch_tokens_per_rank,
                hidden_size,
                group.size(),
                num_experts,
            )

        cls._buffer = Buffer(group, num_ep_buffer_bytes)
        return cls._buffer
107
+
108
+
109
+ class _MooncakeEPDispatcherImpl:
110
    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool,
        num_experts: int,
        num_local_experts: int,
        hidden_size: int,
        params_dtype: torch.dtype,
        return_recv_hook: bool,
        deepep_mode: DeepEPMode,
    ):
        """Set up the Mooncake EP dispatcher implementation.

        Args:
            group: EP process group used for dispatch/combine.
            router_topk: number of experts selected per token.
            permute_fusion: whether permute fusion is enabled (stored only).
            num_experts: total expert count across all ranks.
            num_local_experts: experts hosted on this rank.
            hidden_size: model hidden dimension.
            params_dtype: parameter dtype (stored only).
            return_recv_hook: if True, completion is driven by calling a recv
                hook instead of waiting on a stream event (see dispatch_b).
            deepep_mode: enabled DeepEP-style modes.

        Raises:
            ImportError: if the Mooncake EP package is not installed.
        """
        if not use_mooncake_ep:
            raise ImportError(
                "Mooncake EP is not installed. Please install Mooncake package at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
                "with EP support to run SGLang with Mooncake EP."
            )
        self.group = group
        self.router_topk = router_topk
        self.permute_fusion = permute_fusion
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.params_dtype = params_dtype
        self.return_recv_hook = return_recv_hook
        self.deepep_mode = deepep_mode

        # Bytes per parameter element used for buffer sizing (2 == bf16/fp16).
        self.params_bytes = 2
        # Capacity knob; overridable via environment variable.
        self.num_max_dispatch_tokens_per_rank = get_int_env_var(
            "SGLANG_MOONCAKE_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
        )
        # Mooncake EP dispatch uses FINISHED_SUM_TAG=1024
        # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
        assert self.num_max_dispatch_tokens_per_rank <= 1024

        # First-call bookkeeping and kernel timeout (microseconds).
        self.first_execution = True
        self.timeout_us = 10000000

        # Live view of which EP ranks are currently active (elastic EP).
        self.active_ranks = ElasticEPStateManager.instance().active_ranks

        # Opaque handle returned by the dispatch kernel; set during dispatch.
        self.handle = None
152
+
153
+ def dispatch_a(
154
+ self,
155
+ hidden_states: torch.Tensor,
156
+ topk_output: TopKOutput,
157
+ ):
158
+ topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
159
+ buffer = self._get_buffer()
160
+ topk_ids = topk_ids.to(torch.int64)
161
+ expected_m = (
162
+ hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
163
+ + self.num_experts
164
+ ) // self.num_experts
165
+ hidden_states, masked_m, event, hook = self._dispatch_core(
166
+ hidden_states,
167
+ topk_ids,
168
+ use_fp8=True,
169
+ )
170
+ return (
171
+ hidden_states,
172
+ topk_ids,
173
+ topk_weights,
174
+ masked_m,
175
+ expected_m,
176
+ event,
177
+ hook,
178
+ )
179
+
180
+ def dispatch_b(
181
+ self,
182
+ hidden_states,
183
+ topk_ids,
184
+ topk_weights,
185
+ masked_m,
186
+ expected_m,
187
+ event,
188
+ hook,
189
+ ):
190
+ hook() if self.return_recv_hook else event.current_stream_wait()
191
+
192
+ get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
193
+ masked_m
194
+ )
195
+
196
+ if isinstance(hidden_states, tuple):
197
+ hidden_states, hidden_states_scale = hidden_states
198
+ else:
199
+ hidden_states_scale = None
200
+
201
+ return MooncakeDispatchOutput(
202
+ hidden_states,
203
+ hidden_states_scale,
204
+ topk_ids,
205
+ topk_weights,
206
+ masked_m,
207
+ expected_m,
208
+ )
209
+
210
+ def _dispatch_core(
211
+ self,
212
+ hidden_states: torch.Tensor,
213
+ topk_ids: torch.Tensor,
214
+ use_fp8: bool = False,
215
+ ):
216
+ buffer = self._get_buffer()
217
+ packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
218
+ buffer.dispatch(
219
+ hidden_states,
220
+ topk_ids,
221
+ self.active_ranks,
222
+ self.num_max_dispatch_tokens_per_rank,
223
+ self.num_experts,
224
+ -1 if self.first_execution else self.timeout_us,
225
+ use_fp8=use_fp8,
226
+ async_finish=not self.return_recv_hook,
227
+ return_recv_hook=self.return_recv_hook,
228
+ )
229
+ )
230
+ return packed_recv_hidden, packed_recv_count, event, hook
231
+
232
+ def combine_a(
233
+ self,
234
+ hidden_states: torch.Tensor,
235
+ topk_ids: torch.Tensor,
236
+ topk_weights: torch.Tensor,
237
+ ):
238
+ hidden_states, event, hook = self._combine_core(
239
+ hidden_states,
240
+ topk_ids,
241
+ topk_weights,
242
+ )
243
+ return hidden_states, event, hook
244
+
245
+ def combine_b(self, hidden_states, event, hook):
246
+ hook() if self.return_recv_hook else event.current_stream_wait()
247
+ return hidden_states
248
+
249
+ def _combine_core(
250
+ self,
251
+ hidden_states: torch.Tensor,
252
+ topk_ids: torch.Tensor,
253
+ topk_weights: torch.Tensor,
254
+ ):
255
+ buffer = self._get_buffer()
256
+ combined_hidden_states, event, hook = buffer.combine(
257
+ hidden_states,
258
+ topk_ids,
259
+ topk_weights,
260
+ self.active_ranks,
261
+ -1 if self.first_execution else self.timeout_us,
262
+ self.handle,
263
+ async_finish=not self.return_recv_hook,
264
+ return_recv_hook=self.return_recv_hook,
265
+ )
266
+ self.first_execution = False
267
+ self.handle = None
268
+ return combined_hidden_states, event, hook
269
+
270
+ def _get_buffer(self):
271
+ return EPBuffer.get_ep_buffer(
272
+ self.group,
273
+ self.hidden_size,
274
+ self.params_bytes,
275
+ self.deepep_mode,
276
+ self.num_max_dispatch_tokens_per_rank,
277
+ self.num_experts,
278
+ )
279
+
280
+
281
+ @dataclass
282
+ class _Stage(Enum):
283
+ INITIAL = auto()
284
+ AFTER_DISPATCH_A = auto()
285
+ AFTER_DISPATCH_B = auto()
286
+ AFTER_COMBINE_A = auto()
287
+
288
+
289
+ class MooncakeEPDispatcher(BaseDispatcher):
290
+ def __init__(
291
+ self,
292
+ group: torch.distributed.ProcessGroup,
293
+ router_topk: int,
294
+ permute_fusion: bool = False,
295
+ num_experts: int = None,
296
+ num_local_experts: int = None,
297
+ hidden_size: int = None,
298
+ params_dtype: torch.dtype = None,
299
+ deepep_mode: DeepEPMode = DeepEPMode.AUTO,
300
+ async_finish: bool = False,
301
+ return_recv_hook: bool = False,
302
+ ):
303
+ self.deepep_mode = deepep_mode
304
+
305
+ if self.deepep_mode.enable_low_latency():
306
+ self._low_latency_dispatcher = _MooncakeEPDispatcherImpl(
307
+ group=group,
308
+ router_topk=router_topk,
309
+ permute_fusion=permute_fusion,
310
+ num_experts=num_experts,
311
+ num_local_experts=num_local_experts,
312
+ hidden_size=hidden_size,
313
+ params_dtype=params_dtype,
314
+ return_recv_hook=return_recv_hook,
315
+ deepep_mode=deepep_mode,
316
+ )
317
+ if self.deepep_mode.enable_normal():
318
+ raise NotImplementedError
319
+
320
+ self._stage = _Stage.INITIAL
321
+
322
+ def dispatch(self, *args, **kwargs) -> DispatchOutput:
323
+ self.dispatch_a(*args, **kwargs)
324
+ ret = self.dispatch_b()
325
+ return ret
326
+
327
+ def dispatch_a(
328
+ self,
329
+ hidden_states: torch.Tensor,
330
+ topk_output: TopKOutput,
331
+ ):
332
+ self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
333
+ inner_state = self._get_impl().dispatch_a(
334
+ hidden_states=hidden_states,
335
+ topk_output=topk_output,
336
+ )
337
+ self._dispatch_intermediate_state = inner_state
338
+
339
+ def dispatch_b(self):
340
+ self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
341
+ inner_state = self._dispatch_intermediate_state
342
+ del self._dispatch_intermediate_state
343
+ return self._get_impl().dispatch_b(*inner_state)
344
+
345
+ def combine(self, *args, **kwargs) -> Tuple:
346
+ self.combine_a(*args, **kwargs)
347
+ ret = self.combine_b()
348
+ return ret
349
+
350
+ def combine_a(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ topk_ids: torch.Tensor,
354
+ topk_weights: torch.Tensor,
355
+ overlap_args: Optional = None,
356
+ ):
357
+ self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
358
+ inner_state = self._get_impl().combine_a(
359
+ hidden_states=hidden_states,
360
+ topk_ids=topk_ids,
361
+ topk_weights=topk_weights,
362
+ )
363
+ self._combine_intermediate_state = inner_state
364
+
365
+ def combine_b(self):
366
+ self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
367
+ inner_state = self._combine_intermediate_state
368
+ del self._combine_intermediate_state
369
+ return self._get_impl().combine_b(*inner_state)
370
+
371
+ def _get_impl(self) -> _MooncakeEPDispatcherImpl:
372
+ is_extend_in_batch = get_is_extend_in_batch()
373
+ resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
374
+ if resolved_deepep_mode == DeepEPMode.NORMAL:
375
+ raise NotImplementedError
376
+ elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
377
+ return self._low_latency_dispatcher
378
+ else:
379
+ raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
380
+
381
+ def _update_stage(self, old_stage, new_stage):
382
+ assert self._stage == old_stage
383
+ self._stage = new_stage
384
+
385
+ def set_quant_config(self, quant_config: dict):
386
+ pass
@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple
4
4
 
5
5
  import torch
6
6
 
7
+ from sglang.srt.distributed import (
8
+ get_moe_expert_parallel_rank,
9
+ get_moe_expert_parallel_world_size,
10
+ )
11
+ from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
7
12
  from sglang.srt.layers.moe.token_dispatcher.base import (
8
13
  BaseDispatcher,
9
14
  CombineInput,
@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
11
16
  DispatchOutput,
12
17
  DispatchOutputFormat,
13
18
  )
19
+ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
20
+ from sglang.srt.layers.moe.utils import get_moe_runner_backend
14
21
 
15
22
  if TYPE_CHECKING:
16
23
  from sglang.srt.layers.moe.topk import TopKOutput
@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)
45
52
 
46
53
  class StandardDispatcher(BaseDispatcher):
47
54
 
55
+ def __init__(self, moe_runner_config: MoeRunnerConfig):
56
+ self.moe_ep_size = get_moe_expert_parallel_world_size()
57
+ self.enable_flashinfer_cutlass_moe = (
58
+ get_moe_runner_backend().is_flashinfer_cutlass()
59
+ )
60
+ self.num_experts = moe_runner_config.num_experts
61
+ self.num_local_experts = moe_runner_config.num_local_experts
62
+ self.moe_ep_rank = get_moe_expert_parallel_rank()
63
+ self.local_expert_mapping = None
64
+
48
65
  def dispatch(
49
66
  self, hidden_states: torch.Tensor, topk_output: TopKOutput
50
67
  ) -> DispatchOutput:
68
+
69
+ if (
70
+ self.moe_ep_size > 1
71
+ and not self.enable_flashinfer_cutlass_moe
72
+ and TopKOutputChecker.format_is_standard(topk_output)
73
+ ):
74
+ if self.local_expert_mapping is None:
75
+ self.local_expert_mapping = torch.full(
76
+ (self.num_experts,), -1, dtype=torch.int32, device="cuda"
77
+ )
78
+ self.local_expert_mapping[
79
+ self.moe_ep_rank
80
+ * self.num_local_experts : (self.moe_ep_rank + 1)
81
+ * self.num_local_experts
82
+ ] = torch.arange(
83
+ 0, self.num_local_experts, dtype=torch.int32, device="cuda"
84
+ )
85
+
86
+ if self.local_expert_mapping is not None:
87
+ if TopKOutputChecker.format_is_standard(topk_output):
88
+ topk_output = topk_output._replace(
89
+ topk_ids=self.local_expert_mapping[topk_output.topk_ids]
90
+ )
91
+ elif TopKOutputChecker.format_is_triton_kernel(topk_output):
92
+ raise NotImplementedError()
93
+
51
94
  return StandardDispatchOutput(
52
95
  hidden_states=hidden_states, topk_output=topk_output
53
96
  )
@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
59
102
  # TODO: this branch should be removed in the future
60
103
  assert isinstance(combine_input, torch.Tensor)
61
104
  return combine_input
105
+
106
+ def set_quant_config(self, quant_config: dict):
107
+ pass
@@ -365,9 +365,10 @@ class TopK(CustomOp):
365
365
  def empty_topk_output(self, device: torch.device) -> TopKOutput:
366
366
  topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
367
367
  topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
368
- topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
368
+ topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
369
+ # FIXME: router_logits should be of size (0, num_experts)
369
370
  router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
370
- return StandardTopKOutput(topk_weights, topk_idx, router_logits)
371
+ return StandardTopKOutput(topk_weights, topk_ids, router_logits)
371
372
 
372
373
 
373
374
  # ------------------------------- TopK implementation -------------------------------------
@@ -13,6 +13,7 @@ from sglang.srt.layers.dp_attention import (
13
13
  get_attention_dp_size,
14
14
  is_dp_attention_enabled,
15
15
  )
16
+ from sglang.srt.utils import log_info_on_rank0
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from sglang.srt.server_args import ServerArgs
@@ -24,6 +25,7 @@ class MoeA2ABackend(Enum):
24
25
 
25
26
  NONE = "none"
26
27
  DEEPEP = "deepep"
28
+ MOONCAKE = "mooncake"
27
29
 
28
30
  @classmethod
29
31
  def _missing_(cls, value):
@@ -40,20 +42,28 @@ class MoeA2ABackend(Enum):
40
42
  def is_deepep(self):
41
43
  return self == MoeA2ABackend.DEEPEP
42
44
 
45
+ def is_mooncake(self):
46
+ return self == MoeA2ABackend.MOONCAKE
47
+
43
48
 
44
49
  class MoeRunnerBackend(Enum):
45
50
 
46
51
  AUTO = "auto"
52
+ DEEP_GEMM = "deep_gemm"
47
53
  TRITON = "triton"
48
54
  TRITON_KERNEL = "triton_kernel"
49
55
  FLASHINFER_TRTLLM = "flashinfer_trtllm"
50
56
  FLASHINFER_CUTLASS = "flashinfer_cutlass"
51
57
  FLASHINFER_MXFP4 = "flashinfer_mxfp4"
52
58
  FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
59
+ CUTLASS = "cutlass"
53
60
 
54
61
  def is_auto(self):
55
62
  return self == MoeRunnerBackend.AUTO
56
63
 
64
+ def is_deep_gemm(self):
65
+ return self == MoeRunnerBackend.DEEP_GEMM
66
+
57
67
  def is_triton(self):
58
68
  return self == MoeRunnerBackend.TRITON
59
69
 
@@ -72,6 +82,9 @@ class MoeRunnerBackend(Enum):
72
82
  def is_flashinfer_mxfp4(self):
73
83
  return self == MoeRunnerBackend.FLASHINFER_MXFP4
74
84
 
85
+ def is_cutlass(self):
86
+ return self == MoeRunnerBackend.CUTLASS
87
+
75
88
 
76
89
  class DeepEPMode(Enum):
77
90
 
@@ -108,6 +121,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
108
121
  MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
109
122
  DEEPEP_MODE: Optional[DeepEPMode] = None
110
123
  IS_TBO_ENABLED: Optional[bool] = None
124
+ IS_SBO_ENABLED: Optional[bool] = None
111
125
  TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
112
126
  DEEPEP_CONFIG: Optional[str] = None
113
127
  DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
@@ -119,6 +133,7 @@ def initialize_moe_config(server_args: ServerArgs):
119
133
  global DEEPEP_MODE
120
134
  global DEEPEP_CONFIG
121
135
  global IS_TBO_ENABLED
136
+ global IS_SBO_ENABLED
122
137
  global TBO_TOKEN_DISTRIBUTION_THRESHOLD
123
138
  global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
124
139
 
@@ -127,6 +142,7 @@ def initialize_moe_config(server_args: ServerArgs):
127
142
  DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
128
143
  DEEPEP_CONFIG = server_args.deepep_config or ""
129
144
  IS_TBO_ENABLED = server_args.enable_two_batch_overlap
145
+ IS_SBO_ENABLED = server_args.enable_single_batch_overlap
130
146
  TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
131
147
  DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
132
148
  server_args.disable_flashinfer_cutlass_moe_fp4_allgather
@@ -144,7 +160,10 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
144
160
  def get_moe_runner_backend() -> MoeRunnerBackend:
145
161
  global MOE_RUNNER_BACKEND
146
162
  if MOE_RUNNER_BACKEND is None:
147
- logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
163
+ log_info_on_rank0(
164
+ logger,
165
+ "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected",
166
+ )
148
167
  MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
149
168
  return MOE_RUNNER_BACKEND
150
169
 
@@ -172,6 +191,13 @@ def is_tbo_enabled() -> bool:
172
191
  return IS_TBO_ENABLED
173
192
 
174
193
 
194
+ def is_sbo_enabled() -> bool:
195
+ global IS_SBO_ENABLED
196
+ if IS_SBO_ENABLED is None:
197
+ IS_SBO_ENABLED = False
198
+ return IS_SBO_ENABLED
199
+
200
+
175
201
  def get_tbo_token_distribution_threshold() -> float:
176
202
  global TBO_TOKEN_DISTRIBUTION_THRESHOLD
177
203
  if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
7
7
  import torch
8
8
  from torch.nn import Parameter
9
9
 
10
+ from sglang.srt.layers.utils import pad_or_narrow_weight
10
11
  from sglang.srt.utils import is_cpu
11
12
 
12
13
  __all__ = [
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
156
157
  )
157
158
  else:
158
159
  if not use_presharded_weights:
159
- loaded_weight = loaded_weight.narrow(
160
- self.output_dim, tp_rank * shard_size, shard_size
161
- )
160
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
161
+ start_idx = tp_rank * shard_size
162
+ end_idx = start_idx + shard_size
163
+ if end_idx > loaded_weight.shape[self.output_dim]:
164
+ loaded_weight = pad_or_narrow_weight(
165
+ loaded_weight, self.output_dim, start_idx, shard_size
166
+ )
167
+ else:
168
+ loaded_weight = loaded_weight.narrow(
169
+ self.output_dim, start_idx, shard_size
170
+ )
162
171
 
163
172
  assert param_data.shape == loaded_weight.shape
164
173
  param_data.copy_(loaded_weight)
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
258
267
 
259
268
  return
260
269
  else:
261
- loaded_weight = loaded_weight.narrow(
262
- self.input_dim, tp_rank * shard_size, shard_size
263
- )
270
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
271
+ start_idx = tp_rank * shard_size
272
+ end_idx = start_idx + shard_size
273
+ if end_idx > loaded_weight.shape[self.input_dim]:
274
+ loaded_weight = pad_or_narrow_weight(
275
+ loaded_weight, self.input_dim, start_idx, shard_size
276
+ )
277
+ else:
278
+ loaded_weight = loaded_weight.narrow(
279
+ self.input_dim, start_idx, shard_size
280
+ )
264
281
 
265
282
  if len(loaded_weight.shape) == 0:
266
283
  loaded_weight = loaded_weight.reshape(1)
@@ -10,10 +10,6 @@ import torch
10
10
  try:
11
11
  from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
12
12
  from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
13
- from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
14
- CompressedTensorsW8A8Fp8MoEMethod,
15
- CompressedTensorsWNA16MoEMethod,
16
- )
17
13
  from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
18
14
  from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
19
15
  from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@@ -72,7 +68,8 @@ if TYPE_CHECKING:
72
68
  BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
73
69
  "fp8": Fp8Config,
74
70
  "blockwise_int8": BlockInt8Config,
75
- "modelopt": ModelOptFp8Config,
71
+ "modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8
72
+ "modelopt_fp8": ModelOptFp8Config,
76
73
  "modelopt_fp4": ModelOptFp4Config,
77
74
  "w8a8_int8": W8A8Int8Config,
78
75
  "w8a8_fp8": W8A8Fp8Config,
@@ -174,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
174
171
  return original_isinstance(obj, classinfo)
175
172
 
176
173
  builtins.isinstance = patched_isinstance
177
-
178
-
179
- def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
180
- """
181
- Monkey patch the apply function of vllm's FusedMoEMethodBase.
182
- Convert sglang arguments to vllm arguments.
183
- """
184
- original_apply = class_obj.apply
185
- sig = inspect.signature(original_apply)
186
- param_names = list(sig.parameters.keys())
187
- has_correction_bias = "e_score_correction_bias" in param_names
188
-
189
- def new_apply(
190
- self,
191
- layer: torch.nn.Module,
192
- x: torch.Tensor,
193
- topk_output: TopKOutput,
194
- *,
195
- activation: str = "silu",
196
- apply_router_weight_on_input: bool = False,
197
- inplace: bool = True,
198
- no_combine: bool = False,
199
- routed_scaling_factor: Optional[float] = None,
200
- ):
201
- assert activation == "silu"
202
- assert inplace and not no_combine
203
-
204
- kwargs = {
205
- "self": self,
206
- "layer": layer,
207
- "x": x,
208
- "topk_output": topk_output,
209
- }
210
- return original_apply(**kwargs)
211
-
212
- setattr(class_obj, "apply", new_apply)
213
-
214
-
215
- def monkey_patch_quant_configs():
216
- """Apply all monkey patches in one place."""
217
-
218
- monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
219
- monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
220
-
221
-
222
- # Only apply monkey patches if vllm is available
223
- if VLLM_AVAILABLE:
224
- monkey_patch_quant_configs()