sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,59 +1,35 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from contextlib import nullcontext
5
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
6
5
 
7
6
  import torch
8
- import triton
9
- import triton.language as tl
10
7
 
11
- from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
8
+ from sglang.srt import single_batch_overlap
9
+ from sglang.srt.layers import deep_gemm_wrapper
12
10
  from sglang.srt.layers.moe import (
13
11
  get_deepep_mode,
14
12
  get_moe_a2a_backend,
15
13
  get_moe_runner_backend,
16
14
  should_use_flashinfer_trtllm_moe,
17
15
  )
18
- from sglang.srt.layers.moe.ep_moe.kernels import (
19
- ep_gather,
20
- ep_scatter,
21
- moe_ep_deepgemm_preprocess,
22
- post_reorder_triton_kernel,
23
- silu_and_mul_masked_post_quant_fwd,
24
- tma_align_input_scale,
25
- )
26
16
  from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
17
+ from sglang.srt.layers.moe.token_dispatcher.deepep import (
18
+ DeepEPLLCombineInput,
19
+ DeepEPNormalCombineInput,
20
+ )
27
21
  from sglang.srt.layers.moe.topk import TopKOutput
28
- from sglang.srt.layers.quantization import deep_gemm_wrapper
29
22
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
30
23
  from sglang.srt.layers.quantization.fp8 import Fp8Config
31
- from sglang.srt.layers.quantization.fp8_kernel import (
32
- is_fp8_fnuz,
33
- sglang_per_token_group_quant_fp8,
34
- )
35
- from sglang.srt.layers.quantization.modelopt_quant import (
36
- CUTEDSL_MOE_NVFP4_DISPATCH,
37
- ModelOptNvFp4FusedMoEMethod,
38
- )
39
- from sglang.srt.managers.schedule_batch import global_server_args_dict
40
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
41
- from sglang.srt.offloader import get_offloader
24
+ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
25
+ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
42
26
  from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
43
- from sglang.srt.utils import (
44
- ceil_div,
45
- dispose_tensor,
46
- get_bool_env_var,
47
- get_int_env_var,
48
- is_cuda,
49
- is_hip,
50
- is_npu,
51
- )
27
+ from sglang.srt.utils import get_bool_env_var, is_hip, is_npu
52
28
 
53
29
  if TYPE_CHECKING:
54
30
  from sglang.srt.layers.moe.token_dispatcher import (
55
- DeepEPLLOutput,
56
- DeepEPNormalOutput,
31
+ DeepEPLLDispatchOutput,
32
+ DeepEPNormalDispatchOutput,
57
33
  DispatchOutput,
58
34
  )
59
35
 
@@ -63,7 +39,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
63
39
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
64
40
 
65
41
  if not (_is_npu or _is_hip):
66
- from sgl_kernel import silu_and_mul
42
+ pass
67
43
 
68
44
  if _use_aiter:
69
45
  from aiter import ActivationType, QuantType
@@ -72,29 +48,14 @@ if _use_aiter:
72
48
  logger = logging.getLogger(__name__)
73
49
 
74
50
 
75
- # TODO(kaixih@nvidia): ideally we should merge this logic into
76
- # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
77
- @torch.compile
78
- def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
79
- temp = x.to(torch.float32).view(torch.int32)
80
- exp = torch.bitwise_right_shift(temp, 23)
81
- mant = torch.bitwise_and(temp, 0x7FFFFF)
82
- is_ru = torch.logical_and(
83
- torch.logical_and((mant > 0), (exp != 0xFE)),
84
- ~torch.logical_and((exp == 0), (mant <= 0x400000)),
85
- )
86
- exp = torch.where(is_ru, exp + 1, exp)
87
- new_x = exp.to(torch.uint8).view(torch.int)
88
- return new_x.transpose(1, 2).contiguous().transpose(1, 2)
89
-
90
-
91
- class EPMoE(FusedMoE):
51
+ class DeepEPMoE(FusedMoE):
92
52
  """
93
- MoE Expert Parallel Impl
94
-
95
-
53
+ MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
54
+ Mooncake EP shares the same class, as they expose the same interface.
96
55
  """
97
56
 
57
+ _has_printed = False
58
+
98
59
  def __init__(
99
60
  self,
100
61
  num_experts: int,
@@ -108,291 +69,50 @@ class EPMoE(FusedMoE):
108
69
  prefix: str = "",
109
70
  activation: str = "silu",
110
71
  routed_scaling_factor: Optional[float] = None,
111
- gemm1_alpha: Optional[float] = None,
112
- gemm1_clamp_limit: Optional[float] = None,
113
- with_bias: bool = False,
114
72
  ):
115
73
  super().__init__(
116
74
  num_experts=num_experts,
75
+ top_k=top_k,
117
76
  hidden_size=hidden_size,
118
77
  intermediate_size=intermediate_size,
119
- num_fused_shared_experts=num_fused_shared_experts,
120
78
  layer_id=layer_id,
121
- top_k=top_k,
79
+ num_fused_shared_experts=num_fused_shared_experts,
122
80
  params_dtype=params_dtype,
123
81
  quant_config=quant_config,
124
82
  prefix=prefix,
125
83
  activation=activation,
126
- # apply_router_weight_on_input=apply_router_weight_on_input,
127
84
  routed_scaling_factor=routed_scaling_factor,
128
- gemm1_alpha=gemm1_alpha,
129
- gemm1_clamp_limit=gemm1_clamp_limit,
130
- with_bias=with_bias,
131
85
  )
132
86
 
133
- self.intermediate_size = intermediate_size
87
+ if _use_aiter or _is_npu:
88
+ self.deprecate_flag = False
89
+ elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance(
90
+ quant_config, Fp8Config
91
+ ):
92
+ self.deprecate_flag = True
93
+ else:
94
+ self.deprecate_flag = False
95
+
96
+ if self.deprecate_flag:
97
+ return
134
98
 
135
99
  if isinstance(quant_config, Fp8Config):
136
100
  self.use_block_quant = getattr(self.quant_method, "block_quant", False)
137
- self.block_shape = (
138
- self.quant_method.quant_config.weight_block_size
139
- if self.use_block_quant
140
- else None
141
- )
142
101
  self.use_fp8_w8a8 = True
143
102
  self.fp8_dtype = torch.float8_e4m3fn
144
- self.activation_scheme = quant_config.activation_scheme
145
- else:
103
+ self.use_w4afp8 = False
104
+ elif isinstance(quant_config, W4AFp8Config):
105
+ self.use_w4afp8 = True
146
106
  self.use_fp8_w8a8 = False
147
107
  self.use_block_quant = False
148
- self.block_shape = None
149
- self.activation_scheme = None
150
-
151
- def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
152
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
153
- return self.forward_deepgemm(hidden_states, topk_output)
154
108
  else:
155
- return super().forward(hidden_states, topk_output)
156
-
157
- def forward_deepgemm(
158
- self,
159
- hidden_states: torch.Tensor,
160
- topk_output: TopKOutput,
161
- ):
162
-
163
- self.w13_weight_fp8 = (
164
- self.w13_weight,
165
- (
166
- self.w13_weight_scale_inv
167
- if self.use_block_quant
168
- else self.w13_weight_scale
169
- ),
170
- )
171
- self.w2_weight_fp8 = (
172
- self.w2_weight,
173
- self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
174
- )
175
-
176
- assert self.quant_method is not None
177
- assert self.moe_runner_config.activation == "silu"
178
-
179
- hidden_states_shape = hidden_states.shape
180
- hidden_states_dtype = hidden_states.dtype
181
- hidden_states_device = hidden_states.device
182
-
183
- topk_weights, topk_ids, _ = topk_output
184
-
185
- if not self.use_block_quant:
186
- # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
187
- scale_block_size = 128
188
- w13_weight_scale_n = 2 * (
189
- (self.intermediate_size + scale_block_size - 1) // scale_block_size
190
- )
191
- w13_weight_scale_k = (
192
- hidden_states_shape[-1] + scale_block_size - 1
193
- ) // scale_block_size
194
- w13_weight_scale = (
195
- self.w13_weight_scale.unsqueeze(1)
196
- .repeat_interleave(w13_weight_scale_n, dim=1)
197
- .unsqueeze(2)
198
- .repeat_interleave(w13_weight_scale_k, dim=2)
199
- )
200
- self.w13_weight_fp8 = (
201
- self.w13_weight,
202
- w13_weight_scale,
203
- )
204
- w2_weight_scale_n = (
205
- hidden_states_shape[-1] + scale_block_size - 1
206
- ) // scale_block_size
207
- w2_weight_scale_k = (
208
- self.intermediate_size + scale_block_size - 1
209
- ) // scale_block_size
210
- w2_weight_scale = (
211
- self.w2_weight_scale.unsqueeze(1)
212
- .repeat_interleave(w2_weight_scale_n, dim=1)
213
- .unsqueeze(2)
214
- .repeat_interleave(w2_weight_scale_k, dim=2)
215
- )
216
- self.w2_weight_fp8 = (
217
- self.w2_weight,
218
- w2_weight_scale,
219
- )
220
-
221
- # PreReorder
222
- m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
223
- moe_ep_deepgemm_preprocess(
224
- topk_ids,
225
- self.num_experts,
226
- hidden_states,
227
- self.top_k,
228
- self.start_expert_id,
229
- self.end_expert_id,
230
- self.block_shape,
231
- )
232
- )
233
-
234
- dispose_tensor(hidden_states)
235
-
236
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
237
- b, s_mn, s_k = gateup_input_scale.shape
238
- assert (
239
- s_mn % 4 == 0 and s_k % 4 == 0
240
- ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
241
-
242
- # GroupGemm-0
243
- gateup_input_fp8 = (
244
- gateup_input,
245
- (
246
- _cast_to_e8m0_with_rounding_up(gateup_input_scale)
247
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
248
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
249
- gateup_input_scale
250
- )
251
- ),
252
- )
253
- num_groups, m, k = gateup_input_fp8[0].size()
254
- n = self.w13_weight.size(1)
255
- gateup_output = torch.empty(
256
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
257
- )
258
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
259
- gateup_input_fp8,
260
- self.w13_weight_fp8,
261
- gateup_output,
262
- masked_m,
263
- expected_m,
264
- )
265
- del gateup_input
266
- del gateup_input_fp8
267
-
268
- # Act
269
- down_input = torch.empty(
270
- (
271
- gateup_output.shape[0],
272
- gateup_output.shape[1],
273
- gateup_output.shape[2] // 2,
274
- ),
275
- device=hidden_states_device,
276
- dtype=self.fp8_dtype,
277
- )
278
- scale_block_size = 128
279
- down_input_scale = torch.empty(
280
- (
281
- gateup_output.shape[0],
282
- gateup_output.shape[1],
283
- gateup_output.shape[2] // 2 // scale_block_size,
284
- ),
285
- device=hidden_states_device,
286
- dtype=torch.float32,
287
- )
288
- silu_and_mul_masked_post_quant_fwd(
289
- gateup_output,
290
- down_input,
291
- down_input_scale,
292
- scale_block_size,
293
- masked_m,
294
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
295
- )
296
- del gateup_output
297
-
298
- # GroupGemm-1
299
- n = self.w2_weight.size(1)
300
- down_input_fp8 = (
301
- down_input,
302
- (
303
- down_input_scale
304
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
305
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
306
- ),
307
- )
308
- down_output = torch.empty(
309
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
310
- )
311
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
312
- down_input_fp8,
313
- self.w2_weight_fp8,
314
- down_output,
315
- masked_m,
316
- expected_m,
317
- )
318
- del down_input
319
- del down_input_fp8
320
-
321
- # PostReorder
322
- output = torch.empty(
323
- hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
324
- )
325
- post_reorder_triton_kernel[(hidden_states_shape[0],)](
326
- down_output,
327
- output,
328
- src2dst,
329
- topk_ids,
330
- topk_weights,
331
- self.start_expert_id,
332
- self.end_expert_id,
333
- self.top_k,
334
- hidden_states_shape[1],
335
- m_max * self.start_expert_id,
336
- BLOCK_SIZE=512,
337
- )
338
- if self.moe_runner_config.routed_scaling_factor is not None:
339
- output *= self.moe_runner_config.routed_scaling_factor
340
- return output
341
-
342
-
343
- class DeepEPMoE(EPMoE):
344
- """
345
- MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
346
- """
347
-
348
- _has_printed = False
109
+ self.use_w4afp8 = False
110
+ self.use_fp8_w8a8 = False
111
+ self.use_block_quant = False
112
+ self.use_w4afp8 = False
349
113
 
350
- def __init__(
351
- self,
352
- num_experts: int,
353
- top_k: int,
354
- hidden_size: int,
355
- intermediate_size: int,
356
- layer_id: int,
357
- num_fused_shared_experts: int = 0,
358
- params_dtype: Optional[torch.dtype] = None,
359
- quant_config: Optional[QuantizationConfig] = None,
360
- prefix: str = "",
361
- activation: str = "silu",
362
- routed_scaling_factor: Optional[float] = None,
363
- ):
364
- super().__init__(
365
- num_experts=num_experts,
366
- top_k=top_k,
367
- hidden_size=hidden_size,
368
- intermediate_size=intermediate_size,
369
- layer_id=layer_id,
370
- num_fused_shared_experts=num_fused_shared_experts,
371
- params_dtype=params_dtype,
372
- quant_config=quant_config,
373
- prefix=prefix,
374
- activation=activation,
375
- routed_scaling_factor=routed_scaling_factor,
376
- )
377
114
  self.deepep_mode = get_deepep_mode()
378
115
 
379
- # TODO: move to the beginning of the file
380
- from sglang.srt.distributed.parallel_state import get_tp_group
381
- from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
382
-
383
- self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
384
- group=get_tp_group().device_group,
385
- router_topk=self.top_k,
386
- permute_fusion=True,
387
- num_experts=self.num_experts,
388
- num_local_experts=self.num_local_experts,
389
- hidden_size=hidden_size,
390
- params_dtype=params_dtype,
391
- deepep_mode=self.deepep_mode,
392
- async_finish=True, # TODO
393
- return_recv_hook=True,
394
- )
395
-
396
116
  if self.deepep_mode.enable_low_latency() and not _is_npu:
397
117
  # NPU supports low_latency deepep without deepgemm
398
118
  assert (
@@ -416,7 +136,7 @@ class DeepEPMoE(EPMoE):
416
136
  self.w13_weight,
417
137
  (
418
138
  self.w13_weight_scale_inv
419
- if self.use_block_quant
139
+ if self.use_block_quant or self.use_w4afp8
420
140
  else self.w13_weight_scale
421
141
  ),
422
142
  )
@@ -424,7 +144,7 @@ class DeepEPMoE(EPMoE):
424
144
  self.w2_weight,
425
145
  (
426
146
  self.w2_weight_scale_inv
427
- if self.use_block_quant
147
+ if self.use_block_quant or self.use_w4afp8
428
148
  else self.w2_weight_scale
429
149
  ),
430
150
  )
@@ -432,95 +152,113 @@ class DeepEPMoE(EPMoE):
432
152
  def forward(
433
153
  self,
434
154
  hidden_states: torch.Tensor,
435
- topk_idx: torch.Tensor,
436
- topk_weights: torch.Tensor,
437
- forward_batch: ForwardBatch,
155
+ topk_output: TopKOutput,
156
+ forward_shared_experts=None,
157
+ alt_stream=None,
158
+ disable_sbo=False,
438
159
  ):
439
- dispatch_output = self.dispatch(
440
- hidden_states, topk_idx, topk_weights, forward_batch
441
- )
442
- hidden_states = self.moe_impl(dispatch_output)
443
- hidden_states = self.combine(
444
- hidden_states,
445
- dispatch_output.topk_idx,
446
- dispatch_output.topk_weights,
447
- forward_batch,
160
+
161
+ if self.deprecate_flag:
162
+ assert forward_shared_experts is None
163
+ assert alt_stream is None
164
+ return super().forward(
165
+ hidden_states,
166
+ topk_output,
167
+ )
168
+
169
+ # We have to call SBO inside MoE to be compatible with hooks used in offloading
170
+ return single_batch_overlap.execute_sbo(
171
+ hidden_states=hidden_states,
172
+ topk_output=topk_output,
173
+ # SBO args
174
+ experts=self,
175
+ forward_shared_experts=forward_shared_experts,
176
+ alt_stream=alt_stream,
177
+ disable_sbo=disable_sbo,
448
178
  )
449
- return hidden_states
450
179
 
451
180
  def dispatch(
452
181
  self,
453
182
  hidden_states: torch.Tensor,
454
- topk_idx: torch.Tensor,
455
- topk_weights: torch.Tensor,
456
- forward_batch: ForwardBatch,
183
+ topk_output: TopKOutput,
457
184
  ):
458
- return self.deepep_dispatcher.dispatch(
185
+ return self.dispatcher.dispatch(
459
186
  hidden_states=hidden_states,
460
- topk_idx=topk_idx,
461
- topk_weights=topk_weights,
462
- forward_batch=forward_batch,
463
- input_global_scale=(
464
- self.w13_input_scale_quant
465
- if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
466
- and self.quant_method.enable_flashinfer_cutedsl_moe
467
- and CUTEDSL_MOE_NVFP4_DISPATCH
468
- else None
469
- ),
187
+ topk_output=topk_output,
470
188
  )
471
189
 
472
- def moe_impl(
190
+ def run_moe_core(
473
191
  self,
474
192
  dispatch_output: DispatchOutput,
475
193
  down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
476
194
  ):
195
+
196
+ if self.deprecate_flag:
197
+ assert down_gemm_overlap_args is None
198
+ return super().run_moe_core(
199
+ dispatch_output,
200
+ )
201
+
477
202
  from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
478
203
 
479
204
  if _use_aiter:
480
205
  assert DispatchOutputChecker.format_is_deepep(dispatch_output)
481
206
  # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
482
- return self.forward_aiter(dispatch_output)
483
- if _is_npu:
207
+ output = self.forward_aiter(dispatch_output)
208
+ elif _is_npu:
484
209
  assert DispatchOutputChecker.format_is_deepep(dispatch_output)
485
- return self.forward_npu(dispatch_output)
486
- if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
487
- assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
488
- return self.forward_deepgemm_contiguous(dispatch_output)
210
+ output = self.forward_npu(dispatch_output)
211
+ elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
212
+ if self.use_w4afp8:
213
+ output = self.forward_cutlass_w4afp8(dispatch_output)
214
+ else:
215
+ assert False, "forward_deepgemm_contiguous is deprecated"
489
216
  elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
490
- if get_moe_runner_backend().is_flashinfer_cutedsl():
491
- return self.forward_flashinfer_cutedsl(
217
+ if (
218
+ get_moe_runner_backend().is_flashinfer_cutedsl()
219
+ and self.quant_config.get_name() == "modelopt_fp4"
220
+ ):
221
+ output = self.forward_flashinfer_cutedsl(
492
222
  dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
493
223
  )
494
- assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
495
- return self.forward_deepgemm_masked(dispatch_output)
496
- else:
497
- raise ValueError(
498
- f"Dispatch output format {dispatch_output.format} is not supported"
499
- )
224
+ elif self.use_w4afp8:
225
+ output = self.forward_cutlass_w4afp8_masked(dispatch_output)
226
+ else:
227
+ assert False, "forward_deepgemm_masked is deprecated"
228
+
229
+ combine_input_wrapper = (
230
+ DeepEPNormalCombineInput
231
+ if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
232
+ else DeepEPLLCombineInput
233
+ )
234
+ return combine_input_wrapper(
235
+ hidden_states=output,
236
+ topk_ids=dispatch_output.topk_ids,
237
+ topk_weights=dispatch_output.topk_weights,
238
+ overlap_args=down_gemm_overlap_args,
239
+ )
500
240
 
501
241
  def combine(
502
242
  self,
503
243
  hidden_states: torch.Tensor,
504
- topk_idx: torch.Tensor,
244
+ topk_ids: torch.Tensor,
505
245
  topk_weights: torch.Tensor,
506
- forward_batch: ForwardBatch,
507
246
  overlap_args: Optional[Dict[str, Any]] = None,
508
247
  ):
509
- return self.deepep_dispatcher.combine(
248
+ return self.dispatcher.combine(
510
249
  hidden_states=hidden_states,
511
- topk_idx=topk_idx,
250
+ topk_ids=topk_ids,
512
251
  topk_weights=topk_weights,
513
- forward_batch=forward_batch,
514
252
  overlap_args=overlap_args,
515
253
  )
516
254
 
517
255
  def forward_aiter(
518
256
  self,
519
- dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
257
+ dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
520
258
  ):
521
- hidden_states, topk_idx, topk_weights = (
259
+ hidden_states, topk_ids, topk_weights = (
522
260
  dispatch_output.hidden_states,
523
- dispatch_output.topk_idx,
261
+ dispatch_output.topk_ids,
524
262
  dispatch_output.topk_weights,
525
263
  )
526
264
  if hidden_states.shape[0] == 0:
@@ -528,15 +266,15 @@ class DeepEPMoE(EPMoE):
528
266
  # in original deepep, idx == -1 meaning invalid and will not be processed.
529
267
  # aiter does not accept -1, we use a expert mask to make these idx invalid
530
268
  # (idx == num_local_experts) meaning not used in aiter fused_moe
531
- topk_idx_copy = topk_idx.to(torch.int32)
532
- topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
269
+ topk_ids_copy = topk_ids.to(torch.int32)
270
+ topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts
533
271
 
534
272
  return fused_moe(
535
273
  hidden_states,
536
274
  self.w13_weight,
537
275
  self.w2_weight,
538
276
  topk_weights,
539
- topk_idx_copy,
277
+ topk_ids_copy,
540
278
  w1_scale=self.w13_weight_scale_inv,
541
279
  w2_scale=self.w2_weight_scale_inv,
542
280
  quant_type=QuantType.per_128x128,
@@ -548,251 +286,52 @@ class DeepEPMoE(EPMoE):
548
286
  expert_mask=self.expert_mask,
549
287
  )
550
288
 
551
- def forward_deepgemm_contiguous(
552
- self,
553
- dispatch_output: DeepEPNormalOutput,
554
- ):
555
- hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
556
- dispatch_output
557
- )
558
- hidden_states_fp8, hidden_states_scale = hidden_states_fp8
559
- assert self.quant_method is not None
560
- assert self.moe_runner_config.activation == "silu"
561
- if num_recv_tokens_per_expert is None:
562
- return hidden_states_fp8.bfloat16()
563
- all_tokens = sum(num_recv_tokens_per_expert)
564
- if all_tokens <= 0:
565
- return hidden_states_fp8.bfloat16()
566
- M, K = hidden_states_fp8.size()
567
- N = self.w13_weight.size(1)
568
- scale_block_size = 128
569
-
570
- # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
571
- w13_weight_fp8 = (
572
- self.w13_weight,
573
- (
574
- self.w13_weight_scale_inv
575
- if self.use_block_quant
576
- else self.w13_weight_scale
577
- ),
578
- )
579
- w2_weight_fp8 = (
580
- self.w2_weight,
581
- (
582
- self.w2_weight_scale_inv
583
- if self.use_block_quant
584
- else self.w2_weight_scale
585
- ),
586
- )
587
-
588
- hidden_states_fp8_shape = hidden_states_fp8.shape
589
- hidden_states_fp8_device = hidden_states_fp8.device
590
- hidden_states_fp8_dtype = hidden_states_fp8.dtype
591
-
592
- input_tensor = [
593
- torch.empty(
594
- (all_tokens, K),
595
- device=hidden_states_fp8.device,
596
- dtype=hidden_states_fp8.dtype,
597
- ),
598
- (
599
- # TODO check whether need `zeros`
600
- torch.zeros(
601
- (ceil_div(K // 128, 4), all_tokens),
602
- device=hidden_states_fp8.device,
603
- dtype=torch.int,
604
- ).transpose(0, 1)
605
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
606
- else torch.empty(
607
- (all_tokens, K // 128),
608
- device=hidden_states_fp8.device,
609
- dtype=torch.float32,
610
- )
611
- ),
612
- ]
613
- m_indices = torch.empty(
614
- all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
615
- )
616
- output_index = torch.empty_like(topk_idx)
617
-
618
- if get_offloader().forbid_copy_engine_usage:
619
- num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
620
- num_recv_tokens_per_expert
621
- )
622
- else:
623
- num_recv_tokens_per_expert_gpu = torch.tensor(
624
- num_recv_tokens_per_expert,
625
- dtype=torch.int32,
626
- pin_memory=True,
627
- device="cpu",
628
- ).cuda(non_blocking=True)
629
- expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
630
-
631
- ep_scatter(
632
- hidden_states_fp8,
633
- hidden_states_scale,
634
- topk_idx,
635
- num_recv_tokens_per_expert_gpu,
636
- expert_start_loc,
637
- input_tensor[0],
638
- input_tensor[1],
639
- m_indices,
640
- output_index,
641
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
642
- )
643
- dispose_tensor(hidden_states_fp8)
644
-
645
- gateup_output = torch.empty(
646
- (all_tokens, N),
647
- device=hidden_states_fp8_device,
648
- dtype=torch.bfloat16,
649
- )
650
- if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
651
- input_tensor[1] = tma_align_input_scale(input_tensor[1])
652
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
653
- input_tensor, w13_weight_fp8, gateup_output, m_indices
654
- )
655
- del input_tensor
656
- down_input = torch.empty(
657
- (
658
- all_tokens,
659
- N // 2,
660
- ),
661
- device=gateup_output.device,
662
- dtype=torch.bfloat16,
663
- )
664
- silu_and_mul(gateup_output.view(-1, N), down_input)
665
- del gateup_output
666
- down_output = torch.empty(
667
- (all_tokens, K),
668
- device=hidden_states_fp8_device,
669
- dtype=torch.bfloat16,
670
- )
671
- down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
672
- down_input,
673
- scale_block_size,
674
- column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
675
- scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
676
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
677
- )
678
- del down_input
679
- if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
680
- down_input_scale = tma_align_input_scale(down_input_scale)
681
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
682
- (down_input_fp8, down_input_scale),
683
- w2_weight_fp8,
684
- down_output,
685
- m_indices,
686
- )
687
- del down_input_fp8, down_input_scale
688
-
689
- gather_out = torch.empty(
690
- hidden_states_fp8_shape,
691
- device=hidden_states_fp8_device,
692
- dtype=torch.bfloat16,
693
- )
694
- ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
695
-
696
- return gather_out
697
-
698
289
  def forward_flashinfer_cutedsl(
699
290
  self,
700
- dispatch_output: DeepEPLLOutput,
291
+ dispatch_output: DeepEPLLDispatchOutput,
701
292
  down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
702
293
  ):
703
- hidden_states, _, _, masked_m, _ = dispatch_output
294
+ hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
704
295
  assert self.quant_method is not None
705
296
  assert self.moe_runner_config.activation == "silu"
706
297
 
707
298
  output = self.quant_method.apply_without_routing_weights(
708
299
  layer=self,
709
- x=hidden_states,
300
+ x=(hidden_states, hidden_states_scale),
710
301
  masked_m=masked_m,
711
302
  moe_runner_config=self.moe_runner_config,
712
303
  down_gemm_overlap_args=down_gemm_overlap_args,
713
304
  )
714
305
  return output
715
306
 
716
- def forward_deepgemm_masked(
307
+ def forward_cutlass_w4afp8(
717
308
  self,
718
- dispatch_output: DeepEPLLOutput,
309
+ dispatch_output: DeepEPNormalDispatchOutput,
719
310
  ):
720
- hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
721
- assert self.quant_method is not None
722
311
  assert self.moe_runner_config.activation == "silu"
723
-
724
- # GroupGemm-0
725
- num_groups, m, k = hidden_states_fp8[0].size()
726
- n = self.w13_weight.size(1)
727
- expected_m = min(expected_m, m)
728
- gateup_output = torch.empty(
729
- (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
730
- )
731
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
732
- hidden_states_fp8,
733
- self.w13_weight_fp8,
734
- gateup_output,
735
- masked_m,
736
- expected_m,
737
- )
738
- dispose_tensor(hidden_states_fp8[0])
739
-
740
- # Act
741
- down_input = torch.empty(
742
- (
743
- gateup_output.shape[0],
744
- gateup_output.shape[1],
745
- gateup_output.shape[2] // 2,
746
- ),
747
- device=gateup_output.device,
748
- dtype=self.fp8_dtype,
749
- )
750
- scale_block_size = 128
751
- down_input_scale = torch.empty(
752
- (
753
- gateup_output.shape[0],
754
- gateup_output.shape[1],
755
- gateup_output.shape[2] // 2 // scale_block_size,
756
- ),
757
- device=gateup_output.device,
758
- dtype=torch.float32,
759
- )
760
- silu_and_mul_masked_post_quant_fwd(
761
- gateup_output,
762
- down_input,
763
- down_input_scale,
764
- scale_block_size,
765
- masked_m,
766
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
312
+ assert isinstance(self.quant_method, W4AFp8MoEMethod)
313
+ return self.quant_method.apply_deepep_normal(
314
+ layer=self,
315
+ dispatch_output=dispatch_output,
767
316
  )
768
- del gateup_output
769
317
 
770
- # GroupGemm-1
771
- n = self.w2_weight.size(1)
772
- down_input_fp8 = (
773
- down_input,
774
- (
775
- down_input_scale
776
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
777
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
778
- ),
779
- )
780
- down_output = torch.empty(
781
- (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
782
- )
783
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
784
- down_input_fp8,
785
- self.w2_weight_fp8,
786
- down_output,
787
- masked_m,
788
- expected_m,
318
+ def forward_cutlass_w4afp8_masked(
319
+ self,
320
+ dispatch_output: DeepEPLLDispatchOutput,
321
+ ):
322
+ assert self.moe_runner_config.activation == "silu"
323
+ assert isinstance(self.quant_method, W4AFp8MoEMethod)
324
+ assert get_bool_env_var(
325
+ "SGLANG_DEEPEP_BF16_DISPATCH"
326
+ ), "W4AFP8 does not support FP8 dispatch; please set SGLANG_DEEPEP_BF16_DISPATCH=1."
327
+ return self.quant_method.apply_deepep_ll(
328
+ layer=self,
329
+ dispatch_output=dispatch_output,
789
330
  )
790
331
 
791
- return down_output
792
-
793
332
  def forward_npu(
794
333
  self,
795
- dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
334
+ dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
796
335
  ):
797
336
  assert self.quant_method is not None
798
337
  assert self.moe_runner_config.activation == "silu"
@@ -805,14 +344,12 @@ class DeepEPMoE(EPMoE):
805
344
  output_dtype = torch.bfloat16
806
345
  group_list_type = 1
807
346
 
808
- def _forward_normal(dispatch_output: DeepEPNormalOutput):
347
+ def _forward_normal(dispatch_output: DeepEPNormalDispatchOutput):
809
348
  if TYPE_CHECKING:
810
- assert isinstance(dispatch_output, DeepEPNormalOutput)
811
- hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
812
-
813
- if isinstance(hidden_states, tuple):
814
- per_token_scale = hidden_states[1]
815
- hidden_states = hidden_states[0]
349
+ assert isinstance(dispatch_output, DeepEPNormalDispatchOutput)
350
+ hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
351
+ dispatch_output
352
+ )
816
353
 
817
354
  group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
818
355
  hidden_states.device
@@ -822,7 +359,7 @@ class DeepEPMoE(EPMoE):
822
359
  hidden_states = torch_npu.npu_grouped_matmul(
823
360
  x=[hidden_states],
824
361
  weight=[self.w13_weight.permute(0, 2, 1)],
825
- # per_token_scale=[per_token_scale],
362
+ # per_token_scale=[hidden_states_scale],
826
363
  split_item=2,
827
364
  group_list_type=group_list_type,
828
365
  group_type=0,
@@ -842,7 +379,7 @@ class DeepEPMoE(EPMoE):
842
379
  )[0]
843
380
  else:
844
381
  if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
845
- hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
382
+ hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
846
383
  hidden_states
847
384
  )
848
385
  # gmm1: gate_up_proj
@@ -850,7 +387,7 @@ class DeepEPMoE(EPMoE):
850
387
  x=[hidden_states],
851
388
  weight=[self.w13_weight],
852
389
  scale=[self.w13_weight_scale.to(output_dtype)],
853
- per_token_scale=[per_token_scale],
390
+ per_token_scale=[hidden_states_scale],
854
391
  split_item=2,
855
392
  group_list_type=group_list_type,
856
393
  group_type=0,
@@ -879,14 +416,17 @@ class DeepEPMoE(EPMoE):
879
416
 
880
417
  return hidden_states
881
418
 
882
- def _forward_ll(dispatch_output: DeepEPLLOutput):
419
+ def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
883
420
  if TYPE_CHECKING:
884
- assert isinstance(dispatch_output, DeepEPLLOutput)
885
- hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
886
-
887
- if isinstance(hidden_states, tuple):
888
- per_token_scale = hidden_states[1]
889
- hidden_states = hidden_states[0]
421
+ assert isinstance(dispatch_output, DeepEPLLDispatchOutput)
422
+ (
423
+ hidden_states,
424
+ hidden_states_scale,
425
+ topk_ids,
426
+ topk_weights,
427
+ group_list,
428
+ _,
429
+ ) = dispatch_output
890
430
 
891
431
  group_list = group_list.to(torch.int64)
892
432
 
@@ -895,7 +435,7 @@ class DeepEPMoE(EPMoE):
895
435
  hidden_states = torch_npu.npu_grouped_matmul(
896
436
  x=[hidden_states],
897
437
  weight=[self.w13_weight.permute(0, 2, 1)],
898
- # per_token_scale=[per_token_scale],
438
+ # per_token_scale=[hidden_states_scale],
899
439
  split_item=2,
900
440
  group_list_type=group_list_type,
901
441
  group_type=0,
@@ -929,7 +469,7 @@ class DeepEPMoE(EPMoE):
929
469
  hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
930
470
  x=hidden_states,
931
471
  weight_scale=self.w13_weight_scale.to(torch.float32),
932
- activation_scale=per_token_scale,
472
+ activation_scale=hidden_states_scale,
933
473
  bias=None,
934
474
  quant_scale=None,
935
475
  quant_offset=None,
@@ -962,7 +502,7 @@ class DeepEPMoE(EPMoE):
962
502
 
963
503
 
964
504
  def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
965
- if get_moe_a2a_backend().is_deepep():
505
+ if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
966
506
  return DeepEPMoE
967
507
 
968
508
  # NEW: Direct FP4 detection (bypasses EP requirements)
@@ -988,15 +528,4 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
988
528
  return FlashInferFusedMoE
989
529
  if get_moe_runner_backend().is_flashinfer_cutlass():
990
530
  return FusedMoE
991
- if get_moe_expert_parallel_world_size() > 1:
992
- return EPMoE
993
531
  return FusedMoE
994
-
995
-
996
- def copy_list_to_gpu_no_ce(arr: List[int]):
997
- from sgl_kernel.elementwise import copy_to_gpu_no_ce
998
-
999
- tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
1000
- tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
1001
- copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
1002
- return tensor_gpu