sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/mooncake.py (new file)
@@ -0,0 +1,386 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import NamedTuple, Optional, Tuple
+
+from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import get_is_extend_in_batch
+from sglang.srt.layers.moe.token_dispatcher.base import (
+    BaseDispatcher,
+    CombineInput,
+    CombineInputFormat,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.utils import get_int_env_var
+
+try:
+    from mooncake.mooncake_ep_buffer import Buffer
+
+    use_mooncake_ep = True
+except ImportError:
+    use_mooncake_ep = False
+
+from enum import Enum, auto
+
+import torch
+import torch.distributed as dist
+
+logger = logging.getLogger(__name__)
+
+
+class MooncakeDispatchOutput(NamedTuple):
+    """Mooncake EP dispatch output."""
+
+    hidden_states: torch.Tensor
+    hidden_states_scale: Optional[torch.Tensor]
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.DEEPEP_LL
+
+
+assert isinstance(MooncakeDispatchOutput, DispatchOutput)
+
+
+class MooncakeCombineInput(NamedTuple):
+    """Mooncake EP combine input."""
+
+    pass
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_LL
+
+
+assert isinstance(MooncakeCombineInput, CombineInput)
+
+
+class EPBuffer:
+    _buffer = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
+
+    @classmethod
+    def get_ep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = -1,
+        num_experts: int = -1,
+    ):
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_ep_buffer_bytes = 0
+        if deepep_mode.enable_normal():
+            raise NotImplementedError(
+                "Normal mode is not supported for Mooncake EP yet."
+            )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank != -1
+            assert num_experts != -1 and num_experts % group.size() == 0
+            num_ep_buffer_bytes = Buffer.get_ep_buffer_size_hint(
+                num_max_dispatch_tokens_per_rank,
+                hidden_size,
+                group.size(),
+                num_experts,
+            )
+
+        cls._buffer = Buffer(group, num_ep_buffer_bytes)
+        return cls._buffer
+
+
+class _MooncakeEPDispatcherImpl:
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool,
+        num_experts: int,
+        num_local_experts: int,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+        return_recv_hook: bool,
+        deepep_mode: DeepEPMode,
+    ):
+        if not use_mooncake_ep:
+            raise ImportError(
+                "Mooncake EP is not installed. Please install Mooncake package at "
+                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
+                "with EP support to run SGLang with Mooncake EP."
+            )
+        self.group = group
+        self.router_topk = router_topk
+        self.permute_fusion = permute_fusion
+        self.num_experts = num_experts
+        self.num_local_experts = num_local_experts
+        self.hidden_size = hidden_size
+        self.params_dtype = params_dtype
+        self.return_recv_hook = return_recv_hook
+        self.deepep_mode = deepep_mode
+
+        self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = get_int_env_var(
+            "SGLANG_MOONCAKE_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
+        )
+        # Mooncake EP dispatch uses FINISHED_SUM_TAG=1024
+        # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
+        assert self.num_max_dispatch_tokens_per_rank <= 1024
+
+        self.first_execution = True
+        self.timeout_us = 10000000
+
+        self.active_ranks = ElasticEPStateManager.instance().active_ranks
+
+        self.handle = None
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ):
+        topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
+        buffer = self._get_buffer()
+        topk_ids = topk_ids.to(torch.int64)
+        expected_m = (
+            hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
+            + self.num_experts
+        ) // self.num_experts
+        hidden_states, masked_m, event, hook = self._dispatch_core(
+            hidden_states,
+            topk_ids,
+            use_fp8=True,
+        )
+        return (
+            hidden_states,
+            topk_ids,
+            topk_weights,
+            masked_m,
+            expected_m,
+            event,
+            hook,
+        )
+
+    def dispatch_b(
+        self,
+        hidden_states,
+        topk_ids,
+        topk_weights,
+        masked_m,
+        expected_m,
+        event,
+        hook,
+    ):
+        hook() if self.return_recv_hook else event.current_stream_wait()
+
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
+
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_scale = hidden_states
+        else:
+            hidden_states_scale = None
+
+        return MooncakeDispatchOutput(
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_core(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8: bool = False,
+    ):
+        buffer = self._get_buffer()
+        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
+            buffer.dispatch(
+                hidden_states,
+                topk_ids,
+                self.active_ranks,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
+                -1 if self.first_execution else self.timeout_us,
+                use_fp8=use_fp8,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
+            )
+        )
+        return packed_recv_hidden, packed_recv_count, event, hook
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        hidden_states, event, hook = self._combine_core(
+            hidden_states,
+            topk_ids,
+            topk_weights,
+        )
+        return hidden_states, event, hook
+
+    def combine_b(self, hidden_states, event, hook):
+        hook() if self.return_recv_hook else event.current_stream_wait()
+        return hidden_states
+
+    def _combine_core(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.combine(
+            hidden_states,
+            topk_ids,
+            topk_weights,
+            self.active_ranks,
+            -1 if self.first_execution else self.timeout_us,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
+        )
+        self.first_execution = False
+        self.handle = None
+        return combined_hidden_states, event, hook
+
+    def _get_buffer(self):
+        return EPBuffer.get_ep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
+
+@dataclass
+class _Stage(Enum):
+    INITIAL = auto()
+    AFTER_DISPATCH_A = auto()
+    AFTER_DISPATCH_B = auto()
+    AFTER_COMBINE_A = auto()
+
+
+class MooncakeEPDispatcher(BaseDispatcher):
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        num_experts: int = None,
+        num_local_experts: int = None,
+        hidden_size: int = None,
+        params_dtype: torch.dtype = None,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
+        async_finish: bool = False,
+        return_recv_hook: bool = False,
+    ):
+        self.deepep_mode = deepep_mode
+
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher = _MooncakeEPDispatcherImpl(
+                group=group,
+                router_topk=router_topk,
+                permute_fusion=permute_fusion,
+                num_experts=num_experts,
+                num_local_experts=num_local_experts,
+                hidden_size=hidden_size,
+                params_dtype=params_dtype,
+                return_recv_hook=return_recv_hook,
+                deepep_mode=deepep_mode,
+            )
+        if self.deepep_mode.enable_normal():
+            raise NotImplementedError
+
+        self._stage = _Stage.INITIAL
+
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
+        self.dispatch_a(*args, **kwargs)
+        ret = self.dispatch_b()
+        return ret
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ):
+        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
+        inner_state = self._get_impl().dispatch_a(
+            hidden_states=hidden_states,
+            topk_output=topk_output,
+        )
+        self._dispatch_intermediate_state = inner_state
+
+    def dispatch_b(self):
+        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
+        inner_state = self._dispatch_intermediate_state
+        del self._dispatch_intermediate_state
+        return self._get_impl().dispatch_b(*inner_state)
+
+    def combine(self, *args, **kwargs) -> Tuple:
+        self.combine_a(*args, **kwargs)
+        ret = self.combine_b()
+        return ret
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        overlap_args: Optional = None,
+    ):
+        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
+        inner_state = self._get_impl().combine_a(
+            hidden_states=hidden_states,
+            topk_ids=topk_ids,
+            topk_weights=topk_weights,
+        )
+        self._combine_intermediate_state = inner_state
+
+    def combine_b(self):
+        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
+        inner_state = self._combine_intermediate_state
+        del self._combine_intermediate_state
+        return self._get_impl().combine_b(*inner_state)
+
+    def _get_impl(self) -> _MooncakeEPDispatcherImpl:
+        is_extend_in_batch = get_is_extend_in_batch()
+        resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
+            raise NotImplementedError
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
+            return self._low_latency_dispatcher
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+    def _update_stage(self, old_stage, new_stage):
+        assert self._stage == old_stage
+        self._stage = new_stage
+
+    def set_quant_config(self, quant_config: dict):
+        pass
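
Note: the new dispatcher splits both collectives into `_a`/`_b` halves, guarded by the `_Stage` enum, so callers can overlap communication with other work between the two calls. A minimal usage sketch of that staged protocol follows; the process group, model sizes, router outputs, and `run_local_experts` are illustrative placeholders, not values taken from this diff:

# Sketch only: intended call order for MooncakeEPDispatcher.
# `ep_group`, the sizes, `hidden_states`, `topk_output`, and
# `run_local_experts` are assumed placeholders for this example.
import torch

from sglang.srt.layers.moe.token_dispatcher.mooncake import MooncakeEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode

dispatcher = MooncakeEPDispatcher(
    group=ep_group,                            # EP torch.distributed.ProcessGroup
    router_topk=8,
    num_experts=256,
    num_local_experts=256 // ep_group.size(),
    hidden_size=7168,
    params_dtype=torch.bfloat16,
    deepep_mode=DeepEPMode.LOW_LATENCY,
    return_recv_hook=True,
)

# dispatch() runs dispatch_a() then dispatch_b(); _update_stage() asserts the order.
out = dispatcher.dispatch(hidden_states, topk_output)
expert_out = run_local_experts(out)            # hypothetical per-rank expert compute
combined = dispatcher.combine(expert_out, out.topk_ids, out.topk_weights)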
sglang/srt/layers/moe/token_dispatcher/standard.py
@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+)
+from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)
 
 class StandardDispatcher(BaseDispatcher):
 
+    def __init__(self, moe_runner_config: MoeRunnerConfig):
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.enable_flashinfer_cutlass_moe = (
+            get_moe_runner_backend().is_flashinfer_cutlass()
+        )
+        self.num_experts = moe_runner_config.num_experts
+        self.num_local_experts = moe_runner_config.num_local_experts
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.local_expert_mapping = None
+
     def dispatch(
         self, hidden_states: torch.Tensor, topk_output: TopKOutput
     ) -> DispatchOutput:
+
+        if (
+            self.moe_ep_size > 1
+            and not self.enable_flashinfer_cutlass_moe
+            and TopKOutputChecker.format_is_standard(topk_output)
+        ):
+            if self.local_expert_mapping is None:
+                self.local_expert_mapping = torch.full(
+                    (self.num_experts,), -1, dtype=torch.int32, device="cuda"
+                )
+                self.local_expert_mapping[
+                    self.moe_ep_rank
+                    * self.num_local_experts : (self.moe_ep_rank + 1)
+                    * self.num_local_experts
+                ] = torch.arange(
+                    0, self.num_local_experts, dtype=torch.int32, device="cuda"
+                )
+
+        if self.local_expert_mapping is not None:
+            if TopKOutputChecker.format_is_standard(topk_output):
+                topk_output = topk_output._replace(
+                    topk_ids=self.local_expert_mapping[topk_output.topk_ids]
+                )
+            elif TopKOutputChecker.format_is_triton_kernels(topk_output):
+                raise NotImplementedError()
+
         return StandardDispatchOutput(
             hidden_states=hidden_states, topk_output=topk_output
         )
@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
         # TODO: this branch should be removed in the future
         assert isinstance(combine_input, torch.Tensor)
         return combine_input
+
+    def set_quant_config(self, quant_config: dict):
+        pass
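
The `local_expert_mapping` built in `dispatch()` above remaps global expert ids to rank-local ids and marks experts owned by other ranks with -1. A worked example with assumed sizes (num_experts=8, EP world size 2, viewed from rank 1; run on CPU purely for illustration):

# Illustration of the global-to-local expert-id remapping, with example values.
import torch

num_experts, num_local_experts, ep_rank = 8, 4, 1   # assumed sizes, not from the diff
mapping = torch.full((num_experts,), -1, dtype=torch.int32)
mapping[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] = torch.arange(
    num_local_experts, dtype=torch.int32
)
# mapping == tensor([-1, -1, -1, -1,  0,  1,  2,  3])

topk_ids = torch.tensor([[1, 5], [4, 7]])           # global ids from the router
print(mapping[topk_ids])
# tensor([[-1,  1],
#         [ 0,  3]], dtype=torch.int32)  -> -1 marks experts held by the other rank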
sglang/srt/layers/moe/topk.py
@@ -111,10 +111,10 @@ class TopKOutputChecker:
         return topk_output.format.is_standard()
 
     @staticmethod
-    def format_is_triton_kernel(
+    def format_is_triton_kernels(
         topk_output: TopKOutput,
     ) -> TypeGuard[TritonKernelTopKOutput]:
-        return topk_output.format.is_triton_kernel()
+        return topk_output.format.is_triton_kernels()
 
     @staticmethod
     def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD
 
-    def is_triton_kernel(self) -> bool:
+    def is_triton_kernels(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL
 
     def is_bypassed(self) -> bool:
@@ -254,7 +254,7 @@ class TopK(CustomOp):
     ) -> TopKOutput:
         if self.topk_config.output_format is not None:
             output_format = self.topk_config.output_format
-        elif get_moe_runner_backend().is_triton_kernel():
+        elif get_moe_runner_backend().is_triton_kernels():
             output_format = TopKOutputFormat.TRITON_KERNEL
         elif (
             should_use_flashinfer_trtllm_moe()
@@ -365,9 +365,10 @@ class TopK(CustomOp):
     def empty_topk_output(self, device: torch.device) -> TopKOutput:
         topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
         topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
-        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        # FIXME: router_logits should be of size (0, num_experts)
         router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
-        return StandardTopKOutput(topk_weights, topk_idx, router_logits)
+        return StandardTopKOutput(topk_weights, topk_ids, router_logits)
 
 
 # ------------------------------- TopK implementation -------------------------------------
sglang/srt/layers/moe/utils.py
@@ -13,6 +13,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     is_dp_attention_enabled,
 )
+from sglang.srt.utils import log_info_on_rank0
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
@@ -24,6 +25,7 @@ class MoeA2ABackend(Enum):
 
     NONE = "none"
     DEEPEP = "deepep"
+    MOONCAKE = "mooncake"
 
     @classmethod
     def _missing_(cls, value):
@@ -40,25 +42,33 @@ class MoeA2ABackend(Enum):
     def is_deepep(self):
         return self == MoeA2ABackend.DEEPEP
 
+    def is_mooncake(self):
+        return self == MoeA2ABackend.MOONCAKE
+
 
 class MoeRunnerBackend(Enum):
 
     AUTO = "auto"
+    DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
-    TRITON_KERNEL = "triton_kernel"
+    TRITON_KERNELS = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
     FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
+    CUTLASS = "cutlass"
 
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
 
+    def is_deep_gemm(self):
+        return self == MoeRunnerBackend.DEEP_GEMM
+
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON
 
-    def is_triton_kernel(self):
-        return self == MoeRunnerBackend.TRITON_KERNEL
+    def is_triton_kernels(self):
+        return self == MoeRunnerBackend.TRITON_KERNELS
 
     def is_flashinfer_trtllm(self):
         return self == MoeRunnerBackend.FLASHINFER_TRTLLM
@@ -72,6 +82,9 @@ class MoeRunnerBackend(Enum):
     def is_flashinfer_mxfp4(self):
         return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
+    def is_cutlass(self):
+        return self == MoeRunnerBackend.CUTLASS
+
 
 class DeepEPMode(Enum):
 
@@ -139,7 +152,6 @@ def initialize_moe_config(server_args: ServerArgs):
 def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
-        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
         MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
 
@@ -147,7 +159,10 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
 def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
-        logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
+        log_info_on_rank0(
+            logger,
+            "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected",
+        )
         MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
 
sglang/srt/layers/quantization/__init__.py
@@ -10,13 +10,8 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -36,9 +31,7 @@ except ImportError as e:
 
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
-        DummyConfig
-    )
+    ) = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
 
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -49,6 +42,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
+from sglang.srt.layers.quantization.gguf import GGUFConfig
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
@@ -72,12 +66,14 @@ if TYPE_CHECKING:
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
-    "modelopt": ModelOptFp8Config,
+    "modelopt": ModelOptFp8Config,  # Auto-detect, defaults to FP8
+    "modelopt_fp8": ModelOptFp8Config,
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "awq": AWQConfig,
     "awq_marlin": AWQMarlinConfig,
+    "gguf": GGUFConfig,
     "gptq": GPTQConfig,
     "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
@@ -111,7 +107,6 @@ VLLM_QUANTIZATION_METHODS = {
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "marlin": MarlinConfig,
-    "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
@@ -174,51 +169,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
             return original_isinstance(obj, classinfo)
 
     builtins.isinstance = patched_isinstance
-
-
-def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
-    """
-    Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert sglang arguments to vllm arguments.
-    """
-    original_apply = class_obj.apply
-    sig = inspect.signature(original_apply)
-    param_names = list(sig.parameters.keys())
-    has_correction_bias = "e_score_correction_bias" in param_names
-
-    def new_apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        assert activation == "silu"
-        assert inplace and not no_combine
-
-        kwargs = {
-            "self": self,
-            "layer": layer,
-            "x": x,
-            "topk_output": topk_output,
-        }
-        return original_apply(**kwargs)
-
-    setattr(class_obj, "apply", new_apply)
-
-
-def monkey_patch_quant_configs():
-    """Apply all monkey patches in one place."""
-
-    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-
-
-# Only apply monkey patches if vllm is available
-if VLLM_AVAILABLE:
-    monkey_patch_quant_configs()
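
With GGUF moved from the vllm-backed table into BASE_QUANTIZATION_METHODS, resolving the method no longer goes through vllm's GGUFConfig. A sketch of the lookup implied by the tables above (assuming a normal sglang install):

# Sketch: "gguf" now resolves from sglang's own quantization registry.
from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

gguf_cls = BASE_QUANTIZATION_METHODS["gguf"]
print(gguf_cls.__name__)  # GGUFConfig (the native sglang implementation)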