sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,14 +1,12 @@
  from __future__ import annotations

  import logging
- from contextlib import nullcontext
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

  import torch
- import triton
- import triton.language as tl

- from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+ from sglang.srt import single_batch_overlap
+ from sglang.srt.layers import deep_gemm_wrapper
  from sglang.srt.layers.moe import (
  get_deepep_mode,
  get_moe_a2a_backend,
@@ -18,37 +16,21 @@ from sglang.srt.layers.moe import (
  from sglang.srt.layers.moe.ep_moe.kernels import (
  ep_gather,
  ep_scatter,
- moe_ep_deepgemm_preprocess,
- post_reorder_triton_kernel,
  silu_and_mul_masked_post_quant_fwd,
  tma_align_input_scale,
  )
  from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
  from sglang.srt.layers.moe.topk import TopKOutput
- from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8 import Fp8Config
  from sglang.srt.layers.quantization.fp8_kernel import (
  is_fp8_fnuz,
  sglang_per_token_group_quant_fp8,
  )
- from sglang.srt.layers.quantization.modelopt_quant import (
- CUTEDSL_MOE_NVFP4_DISPATCH,
- ModelOptNvFp4FusedMoEMethod,
- )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.offloader import get_offloader
+ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
  from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
- from sglang.srt.utils import (
- ceil_div,
- dispose_tensor,
- get_bool_env_var,
- get_int_env_var,
- is_cuda,
- is_hip,
- is_npu,
- )
+ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+ from sglang.srt.utils.offloader import get_offloader

  if TYPE_CHECKING:
  from sglang.srt.layers.moe.token_dispatcher import (
@@ -72,29 +54,14 @@ if _use_aiter:
  logger = logging.getLogger(__name__)


- # TODO(kaixih@nvidia): ideally we should merge this logic into
- # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
- @torch.compile
- def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
- temp = x.to(torch.float32).view(torch.int32)
- exp = torch.bitwise_right_shift(temp, 23)
- mant = torch.bitwise_and(temp, 0x7FFFFF)
- is_ru = torch.logical_and(
- torch.logical_and((mant > 0), (exp != 0xFE)),
- ~torch.logical_and((exp == 0), (mant <= 0x400000)),
- )
- exp = torch.where(is_ru, exp + 1, exp)
- new_x = exp.to(torch.uint8).view(torch.int)
- return new_x.transpose(1, 2).contiguous().transpose(1, 2)
-
-
- class EPMoE(FusedMoE):
+ class DeepEPMoE(FusedMoE):
  """
- MoE Expert Parallel Impl
-
-
+ MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+ Mooncake EP shares the same class, as they expose the same interface.
  """

+ _has_printed = False
+
  def __init__(
  self,
  num_experts: int,
@@ -108,291 +75,37 @@ class EPMoE(FusedMoE):
  prefix: str = "",
  activation: str = "silu",
  routed_scaling_factor: Optional[float] = None,
- gemm1_alpha: Optional[float] = None,
- gemm1_clamp_limit: Optional[float] = None,
- with_bias: bool = False,
  ):
  super().__init__(
  num_experts=num_experts,
+ top_k=top_k,
  hidden_size=hidden_size,
  intermediate_size=intermediate_size,
- num_fused_shared_experts=num_fused_shared_experts,
  layer_id=layer_id,
- top_k=top_k,
+ num_fused_shared_experts=num_fused_shared_experts,
  params_dtype=params_dtype,
  quant_config=quant_config,
  prefix=prefix,
  activation=activation,
- # apply_router_weight_on_input=apply_router_weight_on_input,
  routed_scaling_factor=routed_scaling_factor,
- gemm1_alpha=gemm1_alpha,
- gemm1_clamp_limit=gemm1_clamp_limit,
- with_bias=with_bias,
  )

- self.intermediate_size = intermediate_size
-
  if isinstance(quant_config, Fp8Config):
  self.use_block_quant = getattr(self.quant_method, "block_quant", False)
- self.block_shape = (
- self.quant_method.quant_config.weight_block_size
- if self.use_block_quant
- else None
- )
  self.use_fp8_w8a8 = True
  self.fp8_dtype = torch.float8_e4m3fn
- self.activation_scheme = quant_config.activation_scheme
- else:
+ self.use_w4afp8 = False
+ elif isinstance(quant_config, W4AFp8Config):
+ self.use_w4afp8 = True
  self.use_fp8_w8a8 = False
  self.use_block_quant = False
- self.block_shape = None
- self.activation_scheme = None
-
- def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
- return self.forward_deepgemm(hidden_states, topk_output)
  else:
- return super().forward(hidden_states, topk_output)
-
- def forward_deepgemm(
- self,
- hidden_states: torch.Tensor,
- topk_output: TopKOutput,
- ):
-
- self.w13_weight_fp8 = (
- self.w13_weight,
- (
- self.w13_weight_scale_inv
- if self.use_block_quant
- else self.w13_weight_scale
- ),
- )
- self.w2_weight_fp8 = (
- self.w2_weight,
- self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
- )
-
- assert self.quant_method is not None
- assert self.moe_runner_config.activation == "silu"
-
- hidden_states_shape = hidden_states.shape
- hidden_states_dtype = hidden_states.dtype
- hidden_states_device = hidden_states.device
-
- topk_weights, topk_ids, _ = topk_output
-
- if not self.use_block_quant:
- # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
- scale_block_size = 128
- w13_weight_scale_n = 2 * (
- (self.intermediate_size + scale_block_size - 1) // scale_block_size
- )
- w13_weight_scale_k = (
- hidden_states_shape[-1] + scale_block_size - 1
- ) // scale_block_size
- w13_weight_scale = (
- self.w13_weight_scale.unsqueeze(1)
- .repeat_interleave(w13_weight_scale_n, dim=1)
- .unsqueeze(2)
- .repeat_interleave(w13_weight_scale_k, dim=2)
- )
- self.w13_weight_fp8 = (
- self.w13_weight,
- w13_weight_scale,
- )
- w2_weight_scale_n = (
- hidden_states_shape[-1] + scale_block_size - 1
- ) // scale_block_size
- w2_weight_scale_k = (
- self.intermediate_size + scale_block_size - 1
- ) // scale_block_size
- w2_weight_scale = (
- self.w2_weight_scale.unsqueeze(1)
- .repeat_interleave(w2_weight_scale_n, dim=1)
- .unsqueeze(2)
- .repeat_interleave(w2_weight_scale_k, dim=2)
- )
- self.w2_weight_fp8 = (
- self.w2_weight,
- w2_weight_scale,
- )
-
- # PreReorder
- m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
- moe_ep_deepgemm_preprocess(
- topk_ids,
- self.num_experts,
- hidden_states,
- self.top_k,
- self.start_expert_id,
- self.end_expert_id,
- self.block_shape,
- )
- )
-
- dispose_tensor(hidden_states)
-
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
- b, s_mn, s_k = gateup_input_scale.shape
- assert (
- s_mn % 4 == 0 and s_k % 4 == 0
- ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
- # GroupGemm-0
- gateup_input_fp8 = (
- gateup_input,
- (
- _cast_to_e8m0_with_rounding_up(gateup_input_scale)
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
- gateup_input_scale
- )
- ),
- )
- num_groups, m, k = gateup_input_fp8[0].size()
- n = self.w13_weight.size(1)
- gateup_output = torch.empty(
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
- )
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
- gateup_input_fp8,
- self.w13_weight_fp8,
- gateup_output,
- masked_m,
- expected_m,
- )
- del gateup_input
- del gateup_input_fp8
-
- # Act
- down_input = torch.empty(
- (
- gateup_output.shape[0],
- gateup_output.shape[1],
- gateup_output.shape[2] // 2,
- ),
- device=hidden_states_device,
- dtype=self.fp8_dtype,
- )
- scale_block_size = 128
- down_input_scale = torch.empty(
- (
- gateup_output.shape[0],
- gateup_output.shape[1],
- gateup_output.shape[2] // 2 // scale_block_size,
- ),
- device=hidden_states_device,
- dtype=torch.float32,
- )
- silu_and_mul_masked_post_quant_fwd(
- gateup_output,
- down_input,
- down_input_scale,
- scale_block_size,
- masked_m,
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
- )
- del gateup_output
-
- # GroupGemm-1
- n = self.w2_weight.size(1)
- down_input_fp8 = (
- down_input,
- (
- down_input_scale
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
- ),
- )
- down_output = torch.empty(
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
- )
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
- down_input_fp8,
- self.w2_weight_fp8,
- down_output,
- masked_m,
- expected_m,
- )
- del down_input
- del down_input_fp8
-
- # PostReorder
- output = torch.empty(
- hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
- )
- post_reorder_triton_kernel[(hidden_states_shape[0],)](
- down_output,
- output,
- src2dst,
- topk_ids,
- topk_weights,
- self.start_expert_id,
- self.end_expert_id,
- self.top_k,
- hidden_states_shape[1],
- m_max * self.start_expert_id,
- BLOCK_SIZE=512,
- )
- if self.moe_runner_config.routed_scaling_factor is not None:
- output *= self.moe_runner_config.routed_scaling_factor
- return output
-
-
- class DeepEPMoE(EPMoE):
- """
- MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
- """
-
- _has_printed = False
+ self.use_fp8_w8a8 = False
+ self.use_block_quant = False
+ self.use_w4afp8 = False

- def __init__(
- self,
- num_experts: int,
- top_k: int,
- hidden_size: int,
- intermediate_size: int,
- layer_id: int,
- num_fused_shared_experts: int = 0,
- params_dtype: Optional[torch.dtype] = None,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- activation: str = "silu",
- routed_scaling_factor: Optional[float] = None,
- ):
- super().__init__(
- num_experts=num_experts,
- top_k=top_k,
- hidden_size=hidden_size,
- intermediate_size=intermediate_size,
- layer_id=layer_id,
- num_fused_shared_experts=num_fused_shared_experts,
- params_dtype=params_dtype,
- quant_config=quant_config,
- prefix=prefix,
- activation=activation,
- routed_scaling_factor=routed_scaling_factor,
- )

  self.deepep_mode = get_deepep_mode()

- # TODO: move to the beginning of the file
- from sglang.srt.distributed.parallel_state import get_tp_group
- from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-
- self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
- group=get_tp_group().device_group,
- router_topk=self.top_k,
- permute_fusion=True,
- num_experts=self.num_experts,
- num_local_experts=self.num_local_experts,
- hidden_size=hidden_size,
- params_dtype=params_dtype,
- deepep_mode=self.deepep_mode,
- async_finish=True, # TODO
- return_recv_hook=True,
- )
-
  if self.deepep_mode.enable_low_latency() and not _is_npu:
  # NPU supports low_latency deepep without deepgemm
  assert (
@@ -416,7 +129,7 @@ class DeepEPMoE(EPMoE):
  self.w13_weight,
  (
  self.w13_weight_scale_inv
- if self.use_block_quant
+ if self.use_block_quant or self.use_w4afp8
  else self.w13_weight_scale
  ),
  )
@@ -424,7 +137,7 @@ class DeepEPMoE(EPMoE):
  self.w2_weight,
  (
  self.w2_weight_scale_inv
- if self.use_block_quant
+ if self.use_block_quant or self.use_w4afp8
  else self.w2_weight_scale
  ),
  )
@@ -432,44 +145,34 @@ class DeepEPMoE(EPMoE):
  def forward(
  self,
  hidden_states: torch.Tensor,
- topk_idx: torch.Tensor,
- topk_weights: torch.Tensor,
- forward_batch: ForwardBatch,
+ topk_output: TopKOutput,
+ forward_shared_experts=None,
+ alt_stream=None,
+ disable_sbo=False,
  ):
- dispatch_output = self.dispatch(
- hidden_states, topk_idx, topk_weights, forward_batch
- )
- hidden_states = self.moe_impl(dispatch_output)
- hidden_states = self.combine(
- hidden_states,
- dispatch_output.topk_idx,
- dispatch_output.topk_weights,
- forward_batch,
+
+ # We have to call SBO inside MoE to be compatible with hooks used in offloading
+ return single_batch_overlap.execute_sbo(
+ hidden_states=hidden_states,
+ topk_output=topk_output,
+ # SBO args
+ experts=self,
+ forward_shared_experts=forward_shared_experts,
+ alt_stream=alt_stream,
+ disable_sbo=disable_sbo,
  )
- return hidden_states

  def dispatch(
  self,
  hidden_states: torch.Tensor,
- topk_idx: torch.Tensor,
- topk_weights: torch.Tensor,
- forward_batch: ForwardBatch,
+ topk_output: TopKOutput,
  ):
- return self.deepep_dispatcher.dispatch(
+ return self.dispatcher.dispatch(
  hidden_states=hidden_states,
- topk_idx=topk_idx,
- topk_weights=topk_weights,
- forward_batch=forward_batch,
- input_global_scale=(
- self.w13_input_scale_quant
- if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
- and self.quant_method.enable_flashinfer_cutedsl_moe
- and CUTEDSL_MOE_NVFP4_DISPATCH
- else None
- ),
+ topk_output=topk_output,
  )

- def moe_impl(
+ def run_moe_core(
  self,
  dispatch_output: DispatchOutput,
  down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
@@ -484,14 +187,20 @@ class DeepEPMoE(EPMoE):
  assert DispatchOutputChecker.format_is_deepep(dispatch_output)
  return self.forward_npu(dispatch_output)
  if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+ if self.use_w4afp8:
+ return self.forward_cutlass_w4afp8(dispatch_output)
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
  return self.forward_deepgemm_contiguous(dispatch_output)
  elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
- if get_moe_runner_backend().is_flashinfer_cutedsl():
+ if (
+ get_moe_runner_backend().is_flashinfer_cutedsl()
+ and self.quant_config.get_name() == "modelopt_fp4"
+ ):
  return self.forward_flashinfer_cutedsl(
  dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
  )
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+ assert down_gemm_overlap_args is None
  return self.forward_deepgemm_masked(dispatch_output)
  else:
  raise ValueError(
@@ -501,16 +210,14 @@ class DeepEPMoE(EPMoE):
  def combine(
  self,
  hidden_states: torch.Tensor,
- topk_idx: torch.Tensor,
+ topk_ids: torch.Tensor,
  topk_weights: torch.Tensor,
- forward_batch: ForwardBatch,
  overlap_args: Optional[Dict[str, Any]] = None,
  ):
- return self.deepep_dispatcher.combine(
+ return self.dispatcher.combine(
  hidden_states=hidden_states,
- topk_idx=topk_idx,
+ topk_ids=topk_ids,
  topk_weights=topk_weights,
- forward_batch=forward_batch,
  overlap_args=overlap_args,
  )

@@ -518,9 +225,9 @@ class DeepEPMoE(EPMoE):
  self,
  dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
  ):
- hidden_states, topk_idx, topk_weights = (
+ hidden_states, topk_ids, topk_weights = (
  dispatch_output.hidden_states,
- dispatch_output.topk_idx,
+ dispatch_output.topk_ids,
  dispatch_output.topk_weights,
  )
  if hidden_states.shape[0] == 0:
@@ -528,15 +235,15 @@ class DeepEPMoE(EPMoE):
  # in original deepep, idx == -1 meaning invalid and will not be processed.
  # aiter does not accept -1, we use a expert mask to make these idx invalid
  # (idx == num_local_experts) meaning not used in aiter fused_moe
- topk_idx_copy = topk_idx.to(torch.int32)
- topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
+ topk_ids_copy = topk_ids.to(torch.int32)
+ topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

  return fused_moe(
  hidden_states,
  self.w13_weight,
  self.w2_weight,
  topk_weights,
- topk_idx_copy,
+ topk_ids_copy,
  w1_scale=self.w13_weight_scale_inv,
  w2_scale=self.w2_weight_scale_inv,
  quant_type=QuantType.per_128x128,
@@ -552,22 +259,24 @@ class DeepEPMoE(EPMoE):
  self,
  dispatch_output: DeepEPNormalOutput,
  ):
- hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
- dispatch_output
- )
- hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+ (
+ hidden_states,
+ hidden_states_scale,
+ topk_ids,
+ topk_weights,
+ num_recv_tokens_per_expert,
+ ) = dispatch_output
  assert self.quant_method is not None
  assert self.moe_runner_config.activation == "silu"
  if num_recv_tokens_per_expert is None:
- return hidden_states_fp8.bfloat16()
+ return hidden_states.bfloat16()
  all_tokens = sum(num_recv_tokens_per_expert)
  if all_tokens <= 0:
- return hidden_states_fp8.bfloat16()
- M, K = hidden_states_fp8.size()
+ return hidden_states.bfloat16()
+ M, K = hidden_states.size()
  N = self.w13_weight.size(1)
  scale_block_size = 128

- # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
  w13_weight_fp8 = (
  self.w13_weight,
  (
@@ -585,35 +294,35 @@ class DeepEPMoE(EPMoE):
  ),
  )

- hidden_states_fp8_shape = hidden_states_fp8.shape
- hidden_states_fp8_device = hidden_states_fp8.device
- hidden_states_fp8_dtype = hidden_states_fp8.dtype
+ hidden_states_shape = hidden_states.shape
+ hidden_states_device = hidden_states.device
+ hidden_states_dtype = hidden_states.dtype

  input_tensor = [
  torch.empty(
  (all_tokens, K),
- device=hidden_states_fp8.device,
- dtype=hidden_states_fp8.dtype,
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
  ),
  (
  # TODO check whether need `zeros`
  torch.zeros(
  (ceil_div(K // 128, 4), all_tokens),
- device=hidden_states_fp8.device,
+ device=hidden_states.device,
  dtype=torch.int,
  ).transpose(0, 1)
  if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
  else torch.empty(
  (all_tokens, K // 128),
- device=hidden_states_fp8.device,
+ device=hidden_states.device,
  dtype=torch.float32,
  )
  ),
  ]
  m_indices = torch.empty(
- all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+ all_tokens, device=hidden_states.device, dtype=torch.int32
  )
- output_index = torch.empty_like(topk_idx)
+ output_index = torch.empty_like(topk_ids)

  if get_offloader().forbid_copy_engine_usage:
  num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
@@ -629,9 +338,9 @@ class DeepEPMoE(EPMoE):
  expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

  ep_scatter(
- hidden_states_fp8,
+ hidden_states,
  hidden_states_scale,
- topk_idx,
+ topk_ids,
  num_recv_tokens_per_expert_gpu,
  expert_start_loc,
  input_tensor[0],
@@ -640,11 +349,11 @@ class DeepEPMoE(EPMoE):
  output_index,
  scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
  )
- dispose_tensor(hidden_states_fp8)
+ dispose_tensor(hidden_states)

  gateup_output = torch.empty(
  (all_tokens, N),
- device=hidden_states_fp8_device,
+ device=hidden_states_device,
  dtype=torch.bfloat16,
  )
  if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
@@ -665,7 +374,7 @@ class DeepEPMoE(EPMoE):
  del gateup_output
  down_output = torch.empty(
  (all_tokens, K),
- device=hidden_states_fp8_device,
+ device=hidden_states_device,
  dtype=torch.bfloat16,
  )
  down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
@@ -687,11 +396,11 @@ class DeepEPMoE(EPMoE):
  del down_input_fp8, down_input_scale

  gather_out = torch.empty(
- hidden_states_fp8_shape,
- device=hidden_states_fp8_device,
+ hidden_states_shape,
+ device=hidden_states_device,
  dtype=torch.bfloat16,
  )
- ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+ ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)

  return gather_out

@@ -700,42 +409,56 @@ class DeepEPMoE(EPMoE):
  dispatch_output: DeepEPLLOutput,
  down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
  ):
- hidden_states, _, _, masked_m, _ = dispatch_output
+ hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
  assert self.quant_method is not None
  assert self.moe_runner_config.activation == "silu"

  output = self.quant_method.apply_without_routing_weights(
  layer=self,
- x=hidden_states,
+ x=(hidden_states, hidden_states_scale),
  masked_m=masked_m,
  moe_runner_config=self.moe_runner_config,
  down_gemm_overlap_args=down_gemm_overlap_args,
  )
  return output

+ def forward_cutlass_w4afp8(
+ self,
+ dispatch_output: DeepEPNormalOutput,
+ ):
+ assert self.moe_runner_config.activation == "silu"
+ assert isinstance(self.quant_method, W4AFp8MoEMethod)
+ return self.quant_method.apply_deepep_normal(
+ layer=self,
+ dispatch_output=dispatch_output,
+ )
+
  def forward_deepgemm_masked(
  self,
  dispatch_output: DeepEPLLOutput,
  ):
- hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
+ hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
  assert self.quant_method is not None
  assert self.moe_runner_config.activation == "silu"
+ assert (
+ hidden_states_scale.dtype == torch.float32
+ ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"

  # GroupGemm-0
- num_groups, m, k = hidden_states_fp8[0].size()
+ num_groups, m, k = hidden_states.size()
  n = self.w13_weight.size(1)
  expected_m = min(expected_m, m)
  gateup_output = torch.empty(
- (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+ (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
  )
  deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
- hidden_states_fp8,
+ (hidden_states, hidden_states_scale),
  self.w13_weight_fp8,
  gateup_output,
  masked_m,
  expected_m,
  )
- dispose_tensor(hidden_states_fp8[0])
+ dispose_tensor(hidden_states)

  # Act
  down_input = torch.empty(
@@ -808,11 +531,9 @@ class DeepEPMoE(EPMoE):
  def _forward_normal(dispatch_output: DeepEPNormalOutput):
  if TYPE_CHECKING:
  assert isinstance(dispatch_output, DeepEPNormalOutput)
- hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
-
- if isinstance(hidden_states, tuple):
- per_token_scale = hidden_states[1]
- hidden_states = hidden_states[0]
+ hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
+ dispatch_output
+ )

  group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
  hidden_states.device
@@ -822,7 +543,7 @@ class DeepEPMoE(EPMoE):
  hidden_states = torch_npu.npu_grouped_matmul(
  x=[hidden_states],
  weight=[self.w13_weight.permute(0, 2, 1)],
- # per_token_scale=[per_token_scale],
+ # per_token_scale=[hidden_states_scale],
  split_item=2,
  group_list_type=group_list_type,
  group_type=0,
@@ -842,7 +563,7 @@ class DeepEPMoE(EPMoE):
  )[0]
  else:
  if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
- hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+ hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
  hidden_states
  )
  # gmm1: gate_up_proj
@@ -850,7 +571,7 @@ class DeepEPMoE(EPMoE):
  x=[hidden_states],
  weight=[self.w13_weight],
  scale=[self.w13_weight_scale.to(output_dtype)],
- per_token_scale=[per_token_scale],
+ per_token_scale=[hidden_states_scale],
  split_item=2,
  group_list_type=group_list_type,
  group_type=0,
@@ -882,11 +603,14 @@ class DeepEPMoE(EPMoE):
  def _forward_ll(dispatch_output: DeepEPLLOutput):
  if TYPE_CHECKING:
  assert isinstance(dispatch_output, DeepEPLLOutput)
- hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
-
- if isinstance(hidden_states, tuple):
- per_token_scale = hidden_states[1]
- hidden_states = hidden_states[0]
+ (
+ hidden_states,
+ hidden_states_scale,
+ topk_ids,
+ topk_weights,
+ group_list,
+ _,
+ ) = dispatch_output

  group_list = group_list.to(torch.int64)

@@ -895,7 +619,7 @@ class DeepEPMoE(EPMoE):
  hidden_states = torch_npu.npu_grouped_matmul(
  x=[hidden_states],
  weight=[self.w13_weight.permute(0, 2, 1)],
- # per_token_scale=[per_token_scale],
+ # per_token_scale=[hidden_states_scale],
  split_item=2,
  group_list_type=group_list_type,
  group_type=0,
@@ -929,7 +653,7 @@ class DeepEPMoE(EPMoE):
  hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
  x=hidden_states,
  weight_scale=self.w13_weight_scale.to(torch.float32),
- activation_scale=per_token_scale,
+ activation_scale=hidden_states_scale,
  bias=None,
  quant_scale=None,
  quant_offset=None,
@@ -962,7 +686,7 @@ class DeepEPMoE(EPMoE):


  def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
- if get_moe_a2a_backend().is_deepep():
+ if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
  return DeepEPMoE

  # NEW: Direct FP4 detection (bypasses EP requirements)
@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
  return FlashInferFusedMoE
  if get_moe_runner_backend().is_flashinfer_cutlass():
  return FusedMoE
- if get_moe_expert_parallel_world_size() > 1:
- return EPMoE
  return FusedMoE

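The most consequential interface change in these hunks is in DeepEPMoE.forward: it no longer takes topk_idx / topk_weights / forward_batch and no longer chains dispatch -> moe_impl -> combine itself, but instead takes a TopKOutput plus optional single-batch-overlap (SBO) arguments and delegates to single_batch_overlap.execute_sbo (which drives dispatch / run_moe_core / combine on the layer). Below is a minimal caller-side sketch of that signature change; the wrapper functions and Any-typed placeholders are illustrative only and are not part of sglang.

from typing import Any, Optional

def call_moe_0_5_3rc2(moe: Any, hidden_states: Any, topk_idx: Any,
                      topk_weights: Any, forward_batch: Any) -> Any:
    # 0.5.3rc2-style call: raw top-k ids/weights plus the ForwardBatch;
    # the layer ran dispatch -> moe_impl -> combine internally.
    return moe(hidden_states, topk_idx, topk_weights, forward_batch)

def call_moe_0_5_4(moe: Any, hidden_states: Any, topk_output: Any,
                   forward_shared_experts: Optional[Any] = None,
                   alt_stream: Optional[Any] = None,
                   disable_sbo: bool = False) -> Any:
    # 0.5.4-style call: a TopKOutput plus optional SBO arguments; the layer
    # delegates to single_batch_overlap.execute_sbo (see the renamed
    # moe_impl -> run_moe_core hunk above).
    return moe(hidden_states, topk_output,
               forward_shared_experts=forward_shared_experts,
               alt_stream=alt_stream,
               disable_sbo=disable_sbo)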