sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,386 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from typing import NamedTuple, Optional, Tuple
6
+
7
+ from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
8
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
9
+ from sglang.srt.layers.dp_attention import get_is_extend_in_batch
10
+ from sglang.srt.layers.moe.token_dispatcher.base import (
11
+ BaseDispatcher,
12
+ CombineInput,
13
+ CombineInputFormat,
14
+ DispatchOutput,
15
+ DispatchOutputFormat,
16
+ )
17
+ from sglang.srt.layers.moe.topk import TopKOutput
18
+ from sglang.srt.layers.moe.utils import DeepEPMode
19
+ from sglang.srt.utils import get_int_env_var
20
+
21
+ try:
22
+ from mooncake.mooncake_ep_buffer import Buffer
23
+
24
+ use_mooncake_ep = True
25
+ except ImportError:
26
+ use_mooncake_ep = False
27
+
28
+ from enum import Enum, auto
29
+
30
+ import torch
31
+ import torch.distributed as dist
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class MooncakeDispatchOutput(NamedTuple):
    """Tensors produced by a Mooncake EP low-latency dispatch.

    Built by ``_MooncakeEPDispatcherImpl.dispatch_b`` and always reports the
    ``DEEPEP_LL`` dispatch-output format.
    """

    # Received hidden states. When the dispatch returns a (values, scales)
    # tuple, this holds the values half.
    hidden_states: torch.Tensor
    # Scales half of that tuple, or None when a plain tensor was received.
    hidden_states_scale: Optional[torch.Tensor]
    # Routing ids (converted to int64 before dispatch) and their weights.
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor
    # Receive counts returned by the Mooncake buffer (presumably one entry
    # per local expert — TODO confirm against Buffer.dispatch).
    masked_m: torch.Tensor
    # Estimated tokens-per-expert computed in dispatch_a.
    expected_m: int

    @property
    def format(self) -> DispatchOutputFormat:
        """Report the low-latency DeepEP-compatible layout."""
        return DispatchOutputFormat.DEEPEP_LL


# Import-time interface check. DispatchOutput is presumably a runtime-checkable
# Protocol; note the check is performed on the class object, not an instance.
assert isinstance(MooncakeDispatchOutput, DispatchOutput)
52
+
53
+
54
class MooncakeCombineInput(NamedTuple):
    """Field-less marker passed to a Mooncake EP combine.

    It carries no data and only reports the ``DEEPEP_LL`` combine-input format.
    """

    @property
    def format(self) -> CombineInputFormat:
        """Report the low-latency DeepEP-compatible layout."""
        return CombineInputFormat.DEEPEP_LL


# Import-time interface check against CombineInput (presumably a
# runtime-checkable Protocol); performed on the class object itself.
assert isinstance(MooncakeCombineInput, CombineInput)
65
+
66
+
67
class EPBuffer:
    """Process-wide holder for the Mooncake EP communication ``Buffer``.

    The buffer is created on the first ``get_ep_buffer`` call and stored on
    the class. Every later call returns that same object and ignores its
    arguments.
    """

    _buffer = None
    _hidden_size: Optional[int] = None
    _num_max_dispatch_tokens_per_rank: Optional[int] = None
    _num_experts: Optional[int] = None

    @classmethod
    def get_ep_buffer(
        cls,
        group: dist.ProcessGroup,
        hidden_size: int,
        param_bytes: int,
        deepep_mode: DeepEPMode,
        num_max_dispatch_tokens_per_rank: int = -1,
        num_experts: int = -1,
    ):
        """Return the shared buffer, creating it on first use.

        Args:
            group: Process group the buffer communicates over.
            hidden_size: Hidden dimension used to size the buffer.
            param_bytes: Accepted but not read by this method.
            deepep_mode: Only low-latency mode is supported.
            num_max_dispatch_tokens_per_rank: Required (not -1) in low-latency mode.
            num_experts: Required (not -1) in low-latency mode, and must be
                divisible by ``group.size()``.

        Raises:
            NotImplementedError: If ``deepep_mode`` enables normal mode.
        """
        existing = cls._buffer
        if existing is not None:
            return existing

        # Remember the configuration the buffer was built with.
        cls._hidden_size = hidden_size
        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        cls._num_experts = num_experts

        if deepep_mode.enable_normal():
            raise NotImplementedError(
                "Normal mode is not supported for Mooncake EP yet."
            )

        size_in_bytes = 0
        if deepep_mode.enable_low_latency():
            assert num_max_dispatch_tokens_per_rank != -1
            world_size = group.size()
            assert num_experts != -1 and num_experts % world_size == 0
            size_in_bytes = Buffer.get_ep_buffer_size_hint(
                num_max_dispatch_tokens_per_rank,
                hidden_size,
                world_size,
                num_experts,
            )

        cls._buffer = Buffer(group, size_in_bytes)
        return cls._buffer
107
+
108
+
109
class _MooncakeEPDispatcherImpl:
    """Low-latency Mooncake EP dispatch/combine built on the shared EPBuffer.

    Dispatch and combine are each split into two halves. The ``_a`` half
    launches the communication through the Mooncake buffer. The ``_b`` half
    waits for it, either by calling the receive hook or by waiting on the
    returned event, depending on ``return_recv_hook``.
    """

    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool,
        num_experts: int,
        num_local_experts: int,
        hidden_size: int,
        params_dtype: torch.dtype,
        return_recv_hook: bool,
        deepep_mode: DeepEPMode,
    ):
        if not use_mooncake_ep:
            raise ImportError(
                "Mooncake EP is not installed. Please install Mooncake package at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
                "with EP support to run SGLang with Mooncake EP."
            )

        # Configuration forwarded from MooncakeEPDispatcher.
        self.group = group
        self.router_topk = router_topk
        self.permute_fusion = permute_fusion
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.params_dtype = params_dtype
        self.return_recv_hook = return_recv_hook
        self.deepep_mode = deepep_mode

        # NOTE(review): fixed at 2 regardless of params_dtype. It is only
        # forwarded to EPBuffer.get_ep_buffer, which does not read it.
        self.params_bytes = 2
        self.num_max_dispatch_tokens_per_rank = get_int_env_var(
            "SGLANG_MOONCAKE_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
        )
        # Mooncake EP dispatch uses FINISHED_SUM_TAG=1024, so the number of
        # tokens one rank sends to another must stay within that tag.
        # NOTE(review): the original wording says "less than" 1024, but the
        # assert accepts exactly 1024 — confirm the bound Mooncake requires.
        assert self.num_max_dispatch_tokens_per_rank <= 1024

        # The first dispatch/combine passes -1 as the timeout (presumably "no
        # timeout" — confirm against Mooncake's Buffer API). Later calls pass
        # timeout_us, i.e. 10 seconds.
        self.first_execution = True
        self.timeout_us = 10000000

        self.active_ranks = ElasticEPStateManager.instance().active_ranks

        # Handle from the latest dispatch; consumed and cleared by combine.
        self.handle = None

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
    ):
        """Launch an fp8 dispatch and return the state tuple for dispatch_b."""
        routed_ids = topk_output.topk_ids
        routed_weights = topk_output.topk_weights
        ep_buffer = self._get_buffer()
        routed_ids = routed_ids.to(torch.int64)

        # floor(total_routed / num_experts) + 1, where total_routed assumes
        # every rank contributes hidden_states.shape[0] tokens.
        total_routed = (
            hidden_states.shape[0] * ep_buffer.group_size * routed_ids.shape[1]
        )
        expected_m = (total_routed + self.num_experts) // self.num_experts

        received, masked_m, event, hook = self._dispatch_core(
            hidden_states,
            routed_ids,
            use_fp8=True,
        )
        return (
            received,
            routed_ids,
            routed_weights,
            masked_m,
            expected_m,
            event,
            hook,
        )

    def dispatch_b(
        self,
        hidden_states,
        topk_ids,
        topk_weights,
        masked_m,
        expected_m,
        event,
        hook,
    ):
        """Wait for the dispatch launched by dispatch_a and package its output."""
        self._wait_for_recv(event, hook)

        recorder = get_global_expert_distribution_recorder()
        recorder.on_deepep_dispatch_low_latency(masked_m)

        # An fp8 dispatch may hand back (values, scales).
        scales = None
        if isinstance(hidden_states, tuple):
            hidden_states, scales = hidden_states

        return MooncakeDispatchOutput(
            hidden_states,
            scales,
            topk_ids,
            topk_weights,
            masked_m,
            expected_m,
        )

    def _dispatch_core(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        use_fp8: bool = False,
    ):
        """Call Buffer.dispatch, keep its handle, and return the other results."""
        ep_buffer = self._get_buffer()
        timeout = -1 if self.first_execution else self.timeout_us
        received, recv_count, self.handle, event, hook = ep_buffer.dispatch(
            hidden_states,
            topk_ids,
            self.active_ranks,
            self.num_max_dispatch_tokens_per_rank,
            self.num_experts,
            timeout,
            use_fp8=use_fp8,
            async_finish=not self.return_recv_hook,
            return_recv_hook=self.return_recv_hook,
        )
        return received, recv_count, event, hook

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        """Launch the combine; returns (hidden_states, event, hook)."""
        combined, event, hook = self._combine_core(
            hidden_states,
            topk_ids,
            topk_weights,
        )
        return combined, event, hook

    def combine_b(self, hidden_states, event, hook):
        """Wait for the combine launched by combine_a, then return its output."""
        self._wait_for_recv(event, hook)
        return hidden_states

    def _combine_core(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        """Call Buffer.combine with the stored dispatch handle."""
        ep_buffer = self._get_buffer()
        timeout = -1 if self.first_execution else self.timeout_us
        combined, event, hook = ep_buffer.combine(
            hidden_states,
            topk_ids,
            topk_weights,
            self.active_ranks,
            timeout,
            self.handle,
            async_finish=not self.return_recv_hook,
            return_recv_hook=self.return_recv_hook,
        )
        # Later rounds use the finite timeout; the handle has been consumed.
        self.first_execution = False
        self.handle = None
        return combined, event, hook

    def _wait_for_recv(self, event, hook):
        """Block on the receive hook or on the stream event, per configuration."""
        if self.return_recv_hook:
            hook()
        else:
            event.current_stream_wait()

    def _get_buffer(self):
        """Fetch (creating on first use) the process-wide Mooncake buffer."""
        return EPBuffer.get_ep_buffer(
            self.group,
            self.hidden_size,
            self.params_bytes,
            self.deepep_mode,
            self.num_max_dispatch_tokens_per_rank,
            self.num_experts,
        )
279
+
280
+
281
# Not decorated with @dataclass. On an Enum with no annotated fields,
# dataclass generates an __eq__ that returns True for ANY two members.
# That would make every `assert self._stage == old_stage` in
# MooncakeEPDispatcher._update_stage pass unconditionally.
class _Stage(Enum):
    """Position of a MooncakeEPDispatcher in its dispatch/combine cycle.

    The dispatcher moves through the stages in this order:

    * ``INITIAL`` -> ``AFTER_DISPATCH_A`` (dispatch_a)
    * ``AFTER_DISPATCH_A`` -> ``AFTER_DISPATCH_B`` (dispatch_b)
    * ``AFTER_DISPATCH_B`` -> ``AFTER_COMBINE_A`` (combine_a)
    * ``AFTER_COMBINE_A`` -> ``INITIAL`` (combine_b)
    """

    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()
+
288
+
289
class MooncakeEPDispatcher(BaseDispatcher):
    """MoE token dispatcher backed by Mooncake EP; only low-latency mode works.

    Calls are expected in the order dispatch_a -> dispatch_b -> combine_a ->
    combine_b. ``_update_stage`` checks the current stage with an ``assert``
    before advancing it. The ``_a`` halves stash intermediate state on the
    instance, and the matching ``_b`` halves consume and delete it.
    """

    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool = False,
        num_experts: int = None,
        num_local_experts: int = None,
        hidden_size: int = None,
        params_dtype: torch.dtype = None,
        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
        async_finish: bool = False,
        return_recv_hook: bool = False,
    ):
        # NOTE: async_finish is not used here; the implementation derives its
        # async_finish flag from return_recv_hook instead.
        self.deepep_mode = deepep_mode

        if deepep_mode.enable_low_latency():
            self._low_latency_dispatcher = _MooncakeEPDispatcherImpl(
                group=group,
                router_topk=router_topk,
                permute_fusion=permute_fusion,
                num_experts=num_experts,
                num_local_experts=num_local_experts,
                hidden_size=hidden_size,
                params_dtype=params_dtype,
                return_recv_hook=return_recv_hook,
                deepep_mode=deepep_mode,
            )
        if deepep_mode.enable_normal():
            raise NotImplementedError

        self._stage = _Stage.INITIAL

    def dispatch(self, *args, **kwargs) -> DispatchOutput:
        """Run both dispatch halves back to back."""
        self.dispatch_a(*args, **kwargs)
        return self.dispatch_b()

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
    ):
        """Launch the dispatch and stash its state for dispatch_b."""
        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
        impl = self._get_impl()
        self._dispatch_intermediate_state = impl.dispatch_a(
            hidden_states=hidden_states,
            topk_output=topk_output,
        )

    def dispatch_b(self):
        """Finish the pending dispatch and return its output."""
        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
        pending = self._take_intermediate_state("_dispatch_intermediate_state")
        return self._get_impl().dispatch_b(*pending)

    def combine(self, *args, **kwargs) -> Tuple:
        """Run both combine halves back to back."""
        self.combine_a(*args, **kwargs)
        return self.combine_b()

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional = None,
    ):
        """Launch the combine and stash its state for combine_b.

        ``overlap_args`` is accepted but not used.
        """
        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
        impl = self._get_impl()
        self._combine_intermediate_state = impl.combine_a(
            hidden_states=hidden_states,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
        )

    def combine_b(self):
        """Finish the pending combine and return the combined hidden states."""
        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
        pending = self._take_intermediate_state("_combine_intermediate_state")
        return self._get_impl().combine_b(*pending)

    def _take_intermediate_state(self, attr_name: str):
        """Return the state stored under ``attr_name`` and remove it."""
        state = getattr(self, attr_name)
        delattr(self, attr_name)
        return state

    def _get_impl(self) -> _MooncakeEPDispatcherImpl:
        """Pick the implementation for the mode resolved from the current batch."""
        mode = self.deepep_mode.resolve(get_is_extend_in_batch())
        if mode == DeepEPMode.LOW_LATENCY:
            return self._low_latency_dispatcher
        if mode == DeepEPMode.NORMAL:
            raise NotImplementedError
        raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

    def _update_stage(self, old_stage, new_stage):
        """Assert the dispatcher is in ``old_stage``, then move to ``new_stage``."""
        assert self._stage == old_stage
        self._stage = new_stage

    def set_quant_config(self, quant_config: dict):
        """Accept a quantization config; Mooncake EP currently ignores it."""
@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple
4
4
 
5
5
  import torch
6
6
 
7
+ from sglang.srt.distributed import (
8
+ get_moe_expert_parallel_rank,
9
+ get_moe_expert_parallel_world_size,
10
+ )
11
+ from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
7
12
  from sglang.srt.layers.moe.token_dispatcher.base import (
8
13
  BaseDispatcher,
9
14
  CombineInput,
@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
11
16
  DispatchOutput,
12
17
  DispatchOutputFormat,
13
18
  )
19
+ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
20
+ from sglang.srt.layers.moe.utils import get_moe_runner_backend
14
21
 
15
22
  if TYPE_CHECKING:
16
23
  from sglang.srt.layers.moe.topk import TopKOutput
@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)
45
52
 
46
53
  class StandardDispatcher(BaseDispatcher):
47
54
 
55
+ def __init__(self, moe_runner_config: MoeRunnerConfig):
56
+ self.moe_ep_size = get_moe_expert_parallel_world_size()
57
+ self.enable_flashinfer_cutlass_moe = (
58
+ get_moe_runner_backend().is_flashinfer_cutlass()
59
+ )
60
+ self.num_experts = moe_runner_config.num_experts
61
+ self.num_local_experts = moe_runner_config.num_local_experts
62
+ self.moe_ep_rank = get_moe_expert_parallel_rank()
63
+ self.local_expert_mapping = None
64
+
48
65
  def dispatch(
49
66
  self, hidden_states: torch.Tensor, topk_output: TopKOutput
50
67
  ) -> DispatchOutput:
68
+
69
+ if (
70
+ self.moe_ep_size > 1
71
+ and not self.enable_flashinfer_cutlass_moe
72
+ and TopKOutputChecker.format_is_standard(topk_output)
73
+ ):
74
+ if self.local_expert_mapping is None:
75
+ self.local_expert_mapping = torch.full(
76
+ (self.num_experts,), -1, dtype=torch.int32, device="cuda"
77
+ )
78
+ self.local_expert_mapping[
79
+ self.moe_ep_rank
80
+ * self.num_local_experts : (self.moe_ep_rank + 1)
81
+ * self.num_local_experts
82
+ ] = torch.arange(
83
+ 0, self.num_local_experts, dtype=torch.int32, device="cuda"
84
+ )
85
+
86
+ if self.local_expert_mapping is not None:
87
+ if TopKOutputChecker.format_is_standard(topk_output):
88
+ topk_output = topk_output._replace(
89
+ topk_ids=self.local_expert_mapping[topk_output.topk_ids]
90
+ )
91
+ elif TopKOutputChecker.format_is_triton_kernel(topk_output):
92
+ raise NotImplementedError()
93
+
51
94
  return StandardDispatchOutput(
52
95
  hidden_states=hidden_states, topk_output=topk_output
53
96
  )
@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
59
102
  # TODO: this branch should be removed in the future
60
103
  assert isinstance(combine_input, torch.Tensor)
61
104
  return combine_input
105
+
106
+ def set_quant_config(self, quant_config: dict):
107
+ pass
@@ -365,9 +365,10 @@ class TopK(CustomOp):
365
365
  def empty_topk_output(self, device: torch.device) -> TopKOutput:
366
366
  topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
367
367
  topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
368
- topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
368
+ topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
369
+ # FIXME: router_logits should be of size (0, num_experts)
369
370
  router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
370
- return StandardTopKOutput(topk_weights, topk_idx, router_logits)
371
+ return StandardTopKOutput(topk_weights, topk_ids, router_logits)
371
372
 
372
373
 
373
374
  # ------------------------------- TopK implementation -------------------------------------
@@ -13,6 +13,7 @@ from sglang.srt.layers.dp_attention import (
13
13
  get_attention_dp_size,
14
14
  is_dp_attention_enabled,
15
15
  )
16
+ from sglang.srt.utils import log_info_on_rank0
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from sglang.srt.server_args import ServerArgs
@@ -24,6 +25,7 @@ class MoeA2ABackend(Enum):
24
25
 
25
26
  NONE = "none"
26
27
  DEEPEP = "deepep"
28
+ MOONCAKE = "mooncake"
27
29
 
28
30
  @classmethod
29
31
  def _missing_(cls, value):
@@ -40,20 +42,28 @@ class MoeA2ABackend(Enum):
40
42
  def is_deepep(self):
41
43
  return self == MoeA2ABackend.DEEPEP
42
44
 
45
+ def is_mooncake(self):
46
+ return self == MoeA2ABackend.MOONCAKE
47
+
43
48
 
44
49
  class MoeRunnerBackend(Enum):
45
50
 
46
51
  AUTO = "auto"
52
+ DEEP_GEMM = "deep_gemm"
47
53
  TRITON = "triton"
48
54
  TRITON_KERNEL = "triton_kernel"
49
55
  FLASHINFER_TRTLLM = "flashinfer_trtllm"
50
56
  FLASHINFER_CUTLASS = "flashinfer_cutlass"
51
57
  FLASHINFER_MXFP4 = "flashinfer_mxfp4"
52
58
  FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
59
+ CUTLASS = "cutlass"
53
60
 
54
61
  def is_auto(self):
55
62
  return self == MoeRunnerBackend.AUTO
56
63
 
64
+ def is_deep_gemm(self):
65
+ return self == MoeRunnerBackend.DEEP_GEMM
66
+
57
67
  def is_triton(self):
58
68
  return self == MoeRunnerBackend.TRITON
59
69
 
@@ -72,6 +82,9 @@ class MoeRunnerBackend(Enum):
72
82
  def is_flashinfer_mxfp4(self):
73
83
  return self == MoeRunnerBackend.FLASHINFER_MXFP4
74
84
 
85
+ def is_cutlass(self):
86
+ return self == MoeRunnerBackend.CUTLASS
87
+
75
88
 
76
89
  class DeepEPMode(Enum):
77
90
 
@@ -147,7 +160,10 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
147
160
  def get_moe_runner_backend() -> MoeRunnerBackend:
148
161
  global MOE_RUNNER_BACKEND
149
162
  if MOE_RUNNER_BACKEND is None:
150
- logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
163
+ log_info_on_rank0(
164
+ logger,
165
+ "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected",
166
+ )
151
167
  MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
152
168
  return MOE_RUNNER_BACKEND
153
169
 
@@ -10,10 +10,6 @@ import torch
10
10
  try:
11
11
  from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
12
12
  from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
13
- from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
14
- CompressedTensorsW8A8Fp8MoEMethod,
15
- CompressedTensorsWNA16MoEMethod,
16
- )
17
13
  from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
18
14
  from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
19
15
  from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@@ -72,7 +68,8 @@ if TYPE_CHECKING:
72
68
  BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
73
69
  "fp8": Fp8Config,
74
70
  "blockwise_int8": BlockInt8Config,
75
- "modelopt": ModelOptFp8Config,
71
+ "modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8
72
+ "modelopt_fp8": ModelOptFp8Config,
76
73
  "modelopt_fp4": ModelOptFp4Config,
77
74
  "w8a8_int8": W8A8Int8Config,
78
75
  "w8a8_fp8": W8A8Fp8Config,
@@ -174,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
174
171
  return original_isinstance(obj, classinfo)
175
172
 
176
173
  builtins.isinstance = patched_isinstance
177
-
178
-
179
- def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
180
- """
181
- Monkey patch the apply function of vllm's FusedMoEMethodBase.
182
- Convert sglang arguments to vllm arguments.
183
- """
184
- original_apply = class_obj.apply
185
- sig = inspect.signature(original_apply)
186
- param_names = list(sig.parameters.keys())
187
- has_correction_bias = "e_score_correction_bias" in param_names
188
-
189
- def new_apply(
190
- self,
191
- layer: torch.nn.Module,
192
- x: torch.Tensor,
193
- topk_output: TopKOutput,
194
- *,
195
- activation: str = "silu",
196
- apply_router_weight_on_input: bool = False,
197
- inplace: bool = True,
198
- no_combine: bool = False,
199
- routed_scaling_factor: Optional[float] = None,
200
- ):
201
- assert activation == "silu"
202
- assert inplace and not no_combine
203
-
204
- kwargs = {
205
- "self": self,
206
- "layer": layer,
207
- "x": x,
208
- "topk_output": topk_output,
209
- }
210
- return original_apply(**kwargs)
211
-
212
- setattr(class_obj, "apply", new_apply)
213
-
214
-
215
- def monkey_patch_quant_configs():
216
- """Apply all monkey patches in one place."""
217
-
218
- monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
219
- monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
220
-
221
-
222
- # Only apply monkey patches if vllm is available
223
- if VLLM_AVAILABLE:
224
- monkey_patch_quant_configs()