sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import TYPE_CHECKING, List, Optional, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
5
5
 
6
6
  import torch
7
- import triton
8
- import triton.language as tl
9
7
 
10
- from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
8
+ from sglang.srt import single_batch_overlap
9
+ from sglang.srt.layers import deep_gemm_wrapper
11
10
  from sglang.srt.layers.moe import (
12
11
  get_deepep_mode,
13
12
  get_moe_a2a_backend,
@@ -17,31 +16,21 @@ from sglang.srt.layers.moe import (
17
16
  from sglang.srt.layers.moe.ep_moe.kernels import (
18
17
  ep_gather,
19
18
  ep_scatter,
20
- moe_ep_deepgemm_preprocess,
21
- post_reorder_triton_kernel,
22
19
  silu_and_mul_masked_post_quant_fwd,
23
20
  tma_align_input_scale,
24
21
  )
25
22
  from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
26
23
  from sglang.srt.layers.moe.topk import TopKOutput
27
- from sglang.srt.layers.quantization import deep_gemm_wrapper
28
24
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
29
25
  from sglang.srt.layers.quantization.fp8 import Fp8Config
30
26
  from sglang.srt.layers.quantization.fp8_kernel import (
31
27
  is_fp8_fnuz,
32
28
  sglang_per_token_group_quant_fp8,
33
29
  )
34
- from sglang.srt.managers.schedule_batch import global_server_args_dict
35
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
36
- from sglang.srt.offloader import get_offloader
37
- from sglang.srt.utils import (
38
- ceil_div,
39
- dispose_tensor,
40
- get_bool_env_var,
41
- is_cuda,
42
- is_hip,
43
- is_npu,
44
- )
30
+ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
31
+ from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
32
+ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
33
+ from sglang.srt.utils.offloader import get_offloader
45
34
 
46
35
  if TYPE_CHECKING:
47
36
  from sglang.srt.layers.moe.token_dispatcher import (
@@ -65,29 +54,14 @@ if _use_aiter:
65
54
  logger = logging.getLogger(__name__)
66
55
 
67
56
 
68
- # TODO(kaixih@nvidia): ideally we should merge this logic into
69
- # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
70
- @torch.compile
71
- def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
72
- temp = x.to(torch.float32).view(torch.int32)
73
- exp = torch.bitwise_right_shift(temp, 23)
74
- mant = torch.bitwise_and(temp, 0x7FFFFF)
75
- is_ru = torch.logical_and(
76
- torch.logical_and((mant > 0), (exp != 0xFE)),
77
- ~torch.logical_and((exp == 0), (mant <= 0x400000)),
78
- )
79
- exp = torch.where(is_ru, exp + 1, exp)
80
- new_x = exp.to(torch.uint8).view(torch.int)
81
- return new_x.transpose(1, 2).contiguous().transpose(1, 2)
82
-
83
-
84
- class EPMoE(FusedMoE):
57
+ class DeepEPMoE(FusedMoE):
85
58
  """
86
- MoE Expert Parallel Impl
87
-
88
-
59
+ MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
60
+ Mooncake EP shares the same class, as they expose the same interface.
89
61
  """
90
62
 
63
+ _has_printed = False
64
+
91
65
  def __init__(
92
66
  self,
93
67
  num_experts: int,
@@ -101,291 +75,37 @@ class EPMoE(FusedMoE):
101
75
  prefix: str = "",
102
76
  activation: str = "silu",
103
77
  routed_scaling_factor: Optional[float] = None,
104
- gemm1_alpha: Optional[float] = None,
105
- gemm1_clamp_limit: Optional[float] = None,
106
- with_bias: bool = False,
107
78
  ):
108
79
  super().__init__(
109
80
  num_experts=num_experts,
81
+ top_k=top_k,
110
82
  hidden_size=hidden_size,
111
83
  intermediate_size=intermediate_size,
112
- num_fused_shared_experts=num_fused_shared_experts,
113
84
  layer_id=layer_id,
114
- top_k=top_k,
85
+ num_fused_shared_experts=num_fused_shared_experts,
115
86
  params_dtype=params_dtype,
116
87
  quant_config=quant_config,
117
88
  prefix=prefix,
118
89
  activation=activation,
119
- # apply_router_weight_on_input=apply_router_weight_on_input,
120
90
  routed_scaling_factor=routed_scaling_factor,
121
- gemm1_alpha=gemm1_alpha,
122
- gemm1_clamp_limit=gemm1_clamp_limit,
123
- with_bias=with_bias,
124
91
  )
125
92
 
126
- self.intermediate_size = intermediate_size
127
-
128
93
  if isinstance(quant_config, Fp8Config):
129
94
  self.use_block_quant = getattr(self.quant_method, "block_quant", False)
130
- self.block_shape = (
131
- self.quant_method.quant_config.weight_block_size
132
- if self.use_block_quant
133
- else None
134
- )
135
95
  self.use_fp8_w8a8 = True
136
96
  self.fp8_dtype = torch.float8_e4m3fn
137
- self.activation_scheme = quant_config.activation_scheme
138
- else:
97
+ self.use_w4afp8 = False
98
+ elif isinstance(quant_config, W4AFp8Config):
99
+ self.use_w4afp8 = True
139
100
  self.use_fp8_w8a8 = False
140
101
  self.use_block_quant = False
141
- self.block_shape = None
142
- self.activation_scheme = None
143
-
144
- def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
145
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
146
- return self.forward_deepgemm(hidden_states, topk_output)
147
102
  else:
148
- return super().forward(hidden_states, topk_output)
149
-
150
- def forward_deepgemm(
151
- self,
152
- hidden_states: torch.Tensor,
153
- topk_output: TopKOutput,
154
- ):
155
-
156
- self.w13_weight_fp8 = (
157
- self.w13_weight,
158
- (
159
- self.w13_weight_scale_inv
160
- if self.use_block_quant
161
- else self.w13_weight_scale
162
- ),
163
- )
164
- self.w2_weight_fp8 = (
165
- self.w2_weight,
166
- self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
167
- )
168
-
169
- assert self.quant_method is not None
170
- assert self.moe_runner_config.activation == "silu"
171
-
172
- hidden_states_shape = hidden_states.shape
173
- hidden_states_dtype = hidden_states.dtype
174
- hidden_states_device = hidden_states.device
175
-
176
- topk_weights, topk_ids, _ = topk_output
177
-
178
- if not self.use_block_quant:
179
- # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
180
- scale_block_size = 128
181
- w13_weight_scale_n = 2 * (
182
- (self.intermediate_size + scale_block_size - 1) // scale_block_size
183
- )
184
- w13_weight_scale_k = (
185
- hidden_states_shape[-1] + scale_block_size - 1
186
- ) // scale_block_size
187
- w13_weight_scale = (
188
- self.w13_weight_scale.unsqueeze(1)
189
- .repeat_interleave(w13_weight_scale_n, dim=1)
190
- .unsqueeze(2)
191
- .repeat_interleave(w13_weight_scale_k, dim=2)
192
- )
193
- self.w13_weight_fp8 = (
194
- self.w13_weight,
195
- w13_weight_scale,
196
- )
197
- w2_weight_scale_n = (
198
- hidden_states_shape[-1] + scale_block_size - 1
199
- ) // scale_block_size
200
- w2_weight_scale_k = (
201
- self.intermediate_size + scale_block_size - 1
202
- ) // scale_block_size
203
- w2_weight_scale = (
204
- self.w2_weight_scale.unsqueeze(1)
205
- .repeat_interleave(w2_weight_scale_n, dim=1)
206
- .unsqueeze(2)
207
- .repeat_interleave(w2_weight_scale_k, dim=2)
208
- )
209
- self.w2_weight_fp8 = (
210
- self.w2_weight,
211
- w2_weight_scale,
212
- )
213
-
214
- # PreReorder
215
- m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
216
- moe_ep_deepgemm_preprocess(
217
- topk_ids,
218
- self.num_experts,
219
- hidden_states,
220
- self.top_k,
221
- self.start_expert_id,
222
- self.end_expert_id,
223
- self.block_shape,
224
- )
225
- )
226
-
227
- dispose_tensor(hidden_states)
228
-
229
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
230
- b, s_mn, s_k = gateup_input_scale.shape
231
- assert (
232
- s_mn % 4 == 0 and s_k % 4 == 0
233
- ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
234
-
235
- # GroupGemm-0
236
- gateup_input_fp8 = (
237
- gateup_input,
238
- (
239
- _cast_to_e8m0_with_rounding_up(gateup_input_scale)
240
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
241
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
242
- gateup_input_scale
243
- )
244
- ),
245
- )
246
- num_groups, m, k = gateup_input_fp8[0].size()
247
- n = self.w13_weight.size(1)
248
- gateup_output = torch.empty(
249
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
250
- )
251
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
252
- gateup_input_fp8,
253
- self.w13_weight_fp8,
254
- gateup_output,
255
- masked_m,
256
- expected_m,
257
- )
258
- del gateup_input
259
- del gateup_input_fp8
260
-
261
- # Act
262
- down_input = torch.empty(
263
- (
264
- gateup_output.shape[0],
265
- gateup_output.shape[1],
266
- gateup_output.shape[2] // 2,
267
- ),
268
- device=hidden_states_device,
269
- dtype=self.fp8_dtype,
270
- )
271
- scale_block_size = 128
272
- down_input_scale = torch.empty(
273
- (
274
- gateup_output.shape[0],
275
- gateup_output.shape[1],
276
- gateup_output.shape[2] // 2 // scale_block_size,
277
- ),
278
- device=hidden_states_device,
279
- dtype=torch.float32,
280
- )
281
- silu_and_mul_masked_post_quant_fwd(
282
- gateup_output,
283
- down_input,
284
- down_input_scale,
285
- scale_block_size,
286
- masked_m,
287
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
288
- )
289
- del gateup_output
290
-
291
- # GroupGemm-1
292
- n = self.w2_weight.size(1)
293
- down_input_fp8 = (
294
- down_input,
295
- (
296
- down_input_scale
297
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
298
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
299
- ),
300
- )
301
- down_output = torch.empty(
302
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
303
- )
304
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
305
- down_input_fp8,
306
- self.w2_weight_fp8,
307
- down_output,
308
- masked_m,
309
- expected_m,
310
- )
311
- del down_input
312
- del down_input_fp8
313
-
314
- # PostReorder
315
- output = torch.empty(
316
- hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
317
- )
318
- post_reorder_triton_kernel[(hidden_states_shape[0],)](
319
- down_output,
320
- output,
321
- src2dst,
322
- topk_ids,
323
- topk_weights,
324
- self.start_expert_id,
325
- self.end_expert_id,
326
- self.top_k,
327
- hidden_states_shape[1],
328
- m_max * self.start_expert_id,
329
- BLOCK_SIZE=512,
330
- )
331
- if self.moe_runner_config.routed_scaling_factor is not None:
332
- output *= self.moe_runner_config.routed_scaling_factor
333
- return output
334
-
335
-
336
- class DeepEPMoE(EPMoE):
337
- """
338
- MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
339
- """
340
-
341
- _has_printed = False
103
+ self.use_fp8_w8a8 = False
104
+ self.use_block_quant = False
105
+ self.use_w4afp8 = False
342
106
 
343
- def __init__(
344
- self,
345
- num_experts: int,
346
- top_k: int,
347
- hidden_size: int,
348
- intermediate_size: int,
349
- layer_id: int,
350
- num_fused_shared_experts: int = 0,
351
- params_dtype: Optional[torch.dtype] = None,
352
- quant_config: Optional[QuantizationConfig] = None,
353
- prefix: str = "",
354
- activation: str = "silu",
355
- routed_scaling_factor: Optional[float] = None,
356
- ):
357
- super().__init__(
358
- num_experts=num_experts,
359
- top_k=top_k,
360
- hidden_size=hidden_size,
361
- intermediate_size=intermediate_size,
362
- layer_id=layer_id,
363
- num_fused_shared_experts=num_fused_shared_experts,
364
- params_dtype=params_dtype,
365
- quant_config=quant_config,
366
- prefix=prefix,
367
- activation=activation,
368
- routed_scaling_factor=routed_scaling_factor,
369
- )
370
107
  self.deepep_mode = get_deepep_mode()
371
108
 
372
- # TODO: move to the beginning of the file
373
- from sglang.srt.distributed.parallel_state import get_tp_group
374
- from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
375
-
376
- self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
377
- group=get_tp_group().device_group,
378
- router_topk=self.top_k,
379
- permute_fusion=True,
380
- num_experts=self.num_experts,
381
- num_local_experts=self.num_local_experts,
382
- hidden_size=hidden_size,
383
- params_dtype=params_dtype,
384
- deepep_mode=self.deepep_mode,
385
- async_finish=True, # TODO
386
- return_recv_hook=True,
387
- )
388
-
389
109
  if self.deepep_mode.enable_low_latency() and not _is_npu:
390
110
  # NPU supports low_latency deepep without deepgemm
391
111
  assert (
@@ -409,7 +129,7 @@ class DeepEPMoE(EPMoE):
409
129
  self.w13_weight,
410
130
  (
411
131
  self.w13_weight_scale_inv
412
- if self.use_block_quant
132
+ if self.use_block_quant or self.use_w4afp8
413
133
  else self.w13_weight_scale
414
134
  ),
415
135
  )
@@ -417,7 +137,7 @@ class DeepEPMoE(EPMoE):
417
137
  self.w2_weight,
418
138
  (
419
139
  self.w2_weight_scale_inv
420
- if self.use_block_quant
140
+ if self.use_block_quant or self.use_w4afp8
421
141
  else self.w2_weight_scale
422
142
  ),
423
143
  )
@@ -425,37 +145,38 @@ class DeepEPMoE(EPMoE):
425
145
  def forward(
426
146
  self,
427
147
  hidden_states: torch.Tensor,
428
- topk_idx: torch.Tensor,
429
- topk_weights: torch.Tensor,
430
- forward_batch: ForwardBatch,
148
+ topk_output: TopKOutput,
149
+ forward_shared_experts=None,
150
+ alt_stream=None,
151
+ disable_sbo=False,
431
152
  ):
432
- dispatch_output = self.dispatch(
433
- hidden_states, topk_idx, topk_weights, forward_batch
434
- )
435
- hidden_states = self.moe_impl(dispatch_output)
436
- hidden_states = self.combine(
437
- hidden_states,
438
- dispatch_output.topk_idx,
439
- dispatch_output.topk_weights,
440
- forward_batch,
153
+
154
+ # We have to call SBO inside MoE to be compatible with hooks used in offloading
155
+ return single_batch_overlap.execute_sbo(
156
+ hidden_states=hidden_states,
157
+ topk_output=topk_output,
158
+ # SBO args
159
+ experts=self,
160
+ forward_shared_experts=forward_shared_experts,
161
+ alt_stream=alt_stream,
162
+ disable_sbo=disable_sbo,
441
163
  )
442
- return hidden_states
443
164
 
444
165
  def dispatch(
445
166
  self,
446
167
  hidden_states: torch.Tensor,
447
- topk_idx: torch.Tensor,
448
- topk_weights: torch.Tensor,
449
- forward_batch: ForwardBatch,
168
+ topk_output: TopKOutput,
450
169
  ):
451
- return self.deepep_dispatcher.dispatch(
170
+ return self.dispatcher.dispatch(
452
171
  hidden_states=hidden_states,
453
- topk_idx=topk_idx,
454
- topk_weights=topk_weights,
455
- forward_batch=forward_batch,
172
+ topk_output=topk_output,
456
173
  )
457
174
 
458
- def moe_impl(self, dispatch_output: DispatchOutput):
175
+ def run_moe_core(
176
+ self,
177
+ dispatch_output: DispatchOutput,
178
+ down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
179
+ ):
459
180
  from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
460
181
 
461
182
  if _use_aiter:
@@ -466,12 +187,20 @@ class DeepEPMoE(EPMoE):
466
187
  assert DispatchOutputChecker.format_is_deepep(dispatch_output)
467
188
  return self.forward_npu(dispatch_output)
468
189
  if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
190
+ if self.use_w4afp8:
191
+ return self.forward_cutlass_w4afp8(dispatch_output)
469
192
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
470
193
  return self.forward_deepgemm_contiguous(dispatch_output)
471
194
  elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
472
- if get_moe_runner_backend().is_flashinfer_cutedsl():
473
- return self.forward_flashinfer_cutedsl(dispatch_output)
195
+ if (
196
+ get_moe_runner_backend().is_flashinfer_cutedsl()
197
+ and self.quant_config.get_name() == "modelopt_fp4"
198
+ ):
199
+ return self.forward_flashinfer_cutedsl(
200
+ dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
201
+ )
474
202
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
203
+ assert down_gemm_overlap_args is None
475
204
  return self.forward_deepgemm_masked(dispatch_output)
476
205
  else:
477
206
  raise ValueError(
@@ -481,24 +210,24 @@ class DeepEPMoE(EPMoE):
481
210
  def combine(
482
211
  self,
483
212
  hidden_states: torch.Tensor,
484
- topk_idx: torch.Tensor,
213
+ topk_ids: torch.Tensor,
485
214
  topk_weights: torch.Tensor,
486
- forward_batch: ForwardBatch,
215
+ overlap_args: Optional[Dict[str, Any]] = None,
487
216
  ):
488
- return self.deepep_dispatcher.combine(
217
+ return self.dispatcher.combine(
489
218
  hidden_states=hidden_states,
490
- topk_idx=topk_idx,
219
+ topk_ids=topk_ids,
491
220
  topk_weights=topk_weights,
492
- forward_batch=forward_batch,
221
+ overlap_args=overlap_args,
493
222
  )
494
223
 
495
224
  def forward_aiter(
496
225
  self,
497
226
  dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
498
227
  ):
499
- hidden_states, topk_idx, topk_weights = (
228
+ hidden_states, topk_ids, topk_weights = (
500
229
  dispatch_output.hidden_states,
501
- dispatch_output.topk_idx,
230
+ dispatch_output.topk_ids,
502
231
  dispatch_output.topk_weights,
503
232
  )
504
233
  if hidden_states.shape[0] == 0:
@@ -506,15 +235,15 @@ class DeepEPMoE(EPMoE):
506
235
  # in original deepep, idx == -1 meaning invalid and will not be processed.
507
236
  # aiter does not accept -1, we use a expert mask to make these idx invalid
508
237
  # (idx == num_local_experts) meaning not used in aiter fused_moe
509
- topk_idx_copy = topk_idx.to(torch.int32)
510
- topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
238
+ topk_ids_copy = topk_ids.to(torch.int32)
239
+ topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts
511
240
 
512
241
  return fused_moe(
513
242
  hidden_states,
514
243
  self.w13_weight,
515
244
  self.w2_weight,
516
245
  topk_weights,
517
- topk_idx_copy,
246
+ topk_ids_copy,
518
247
  w1_scale=self.w13_weight_scale_inv,
519
248
  w2_scale=self.w2_weight_scale_inv,
520
249
  quant_type=QuantType.per_128x128,
@@ -530,22 +259,24 @@ class DeepEPMoE(EPMoE):
530
259
  self,
531
260
  dispatch_output: DeepEPNormalOutput,
532
261
  ):
533
- hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
534
- dispatch_output
535
- )
536
- hidden_states_fp8, hidden_states_scale = hidden_states_fp8
262
+ (
263
+ hidden_states,
264
+ hidden_states_scale,
265
+ topk_ids,
266
+ topk_weights,
267
+ num_recv_tokens_per_expert,
268
+ ) = dispatch_output
537
269
  assert self.quant_method is not None
538
270
  assert self.moe_runner_config.activation == "silu"
539
271
  if num_recv_tokens_per_expert is None:
540
- return hidden_states_fp8.bfloat16()
272
+ return hidden_states.bfloat16()
541
273
  all_tokens = sum(num_recv_tokens_per_expert)
542
274
  if all_tokens <= 0:
543
- return hidden_states_fp8.bfloat16()
544
- M, K = hidden_states_fp8.size()
275
+ return hidden_states.bfloat16()
276
+ M, K = hidden_states.size()
545
277
  N = self.w13_weight.size(1)
546
278
  scale_block_size = 128
547
279
 
548
- # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
549
280
  w13_weight_fp8 = (
550
281
  self.w13_weight,
551
282
  (
@@ -563,35 +294,35 @@ class DeepEPMoE(EPMoE):
563
294
  ),
564
295
  )
565
296
 
566
- hidden_states_fp8_shape = hidden_states_fp8.shape
567
- hidden_states_fp8_device = hidden_states_fp8.device
568
- hidden_states_fp8_dtype = hidden_states_fp8.dtype
297
+ hidden_states_shape = hidden_states.shape
298
+ hidden_states_device = hidden_states.device
299
+ hidden_states_dtype = hidden_states.dtype
569
300
 
570
301
  input_tensor = [
571
302
  torch.empty(
572
303
  (all_tokens, K),
573
- device=hidden_states_fp8.device,
574
- dtype=hidden_states_fp8.dtype,
304
+ device=hidden_states.device,
305
+ dtype=hidden_states.dtype,
575
306
  ),
576
307
  (
577
308
  # TODO check whether need `zeros`
578
309
  torch.zeros(
579
310
  (ceil_div(K // 128, 4), all_tokens),
580
- device=hidden_states_fp8.device,
311
+ device=hidden_states.device,
581
312
  dtype=torch.int,
582
313
  ).transpose(0, 1)
583
314
  if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
584
315
  else torch.empty(
585
316
  (all_tokens, K // 128),
586
- device=hidden_states_fp8.device,
317
+ device=hidden_states.device,
587
318
  dtype=torch.float32,
588
319
  )
589
320
  ),
590
321
  ]
591
322
  m_indices = torch.empty(
592
- all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
323
+ all_tokens, device=hidden_states.device, dtype=torch.int32
593
324
  )
594
- output_index = torch.empty_like(topk_idx)
325
+ output_index = torch.empty_like(topk_ids)
595
326
 
596
327
  if get_offloader().forbid_copy_engine_usage:
597
328
  num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
@@ -607,9 +338,9 @@ class DeepEPMoE(EPMoE):
607
338
  expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
608
339
 
609
340
  ep_scatter(
610
- hidden_states_fp8,
341
+ hidden_states,
611
342
  hidden_states_scale,
612
- topk_idx,
343
+ topk_ids,
613
344
  num_recv_tokens_per_expert_gpu,
614
345
  expert_start_loc,
615
346
  input_tensor[0],
@@ -618,11 +349,11 @@ class DeepEPMoE(EPMoE):
618
349
  output_index,
619
350
  scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
620
351
  )
621
- dispose_tensor(hidden_states_fp8)
352
+ dispose_tensor(hidden_states)
622
353
 
623
354
  gateup_output = torch.empty(
624
355
  (all_tokens, N),
625
- device=hidden_states_fp8_device,
356
+ device=hidden_states_device,
626
357
  dtype=torch.bfloat16,
627
358
  )
628
359
  if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
@@ -643,7 +374,7 @@ class DeepEPMoE(EPMoE):
643
374
  del gateup_output
644
375
  down_output = torch.empty(
645
376
  (all_tokens, K),
646
- device=hidden_states_fp8_device,
377
+ device=hidden_states_device,
647
378
  dtype=torch.bfloat16,
648
379
  )
649
380
  down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
@@ -665,53 +396,69 @@ class DeepEPMoE(EPMoE):
665
396
  del down_input_fp8, down_input_scale
666
397
 
667
398
  gather_out = torch.empty(
668
- hidden_states_fp8_shape,
669
- device=hidden_states_fp8_device,
399
+ hidden_states_shape,
400
+ device=hidden_states_device,
670
401
  dtype=torch.bfloat16,
671
402
  )
672
- ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
403
+ ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
673
404
 
674
405
  return gather_out
675
406
 
676
407
  def forward_flashinfer_cutedsl(
677
408
  self,
678
409
  dispatch_output: DeepEPLLOutput,
410
+ down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
679
411
  ):
680
- hidden_states, _, _, masked_m, _ = dispatch_output
412
+ hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
681
413
  assert self.quant_method is not None
682
414
  assert self.moe_runner_config.activation == "silu"
683
415
 
684
416
  output = self.quant_method.apply_without_routing_weights(
685
417
  layer=self,
686
- x=hidden_states,
418
+ x=(hidden_states, hidden_states_scale),
687
419
  masked_m=masked_m,
688
420
  moe_runner_config=self.moe_runner_config,
421
+ down_gemm_overlap_args=down_gemm_overlap_args,
689
422
  )
690
423
  return output
691
424
 
425
+ def forward_cutlass_w4afp8(
426
+ self,
427
+ dispatch_output: DeepEPNormalOutput,
428
+ ):
429
+ assert self.moe_runner_config.activation == "silu"
430
+ assert isinstance(self.quant_method, W4AFp8MoEMethod)
431
+ return self.quant_method.apply_deepep_normal(
432
+ layer=self,
433
+ dispatch_output=dispatch_output,
434
+ )
435
+
692
436
  def forward_deepgemm_masked(
693
437
  self,
694
438
  dispatch_output: DeepEPLLOutput,
695
439
  ):
696
- hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
440
+ hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
697
441
  assert self.quant_method is not None
698
442
  assert self.moe_runner_config.activation == "silu"
443
+ assert (
444
+ hidden_states_scale.dtype == torch.float32
445
+ ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"
699
446
 
700
447
  # GroupGemm-0
701
- num_groups, m, k = hidden_states_fp8[0].size()
448
+ num_groups, m, k = hidden_states.size()
702
449
  n = self.w13_weight.size(1)
703
450
  expected_m = min(expected_m, m)
704
451
  gateup_output = torch.empty(
705
- (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
452
+ (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
706
453
  )
707
454
  deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
708
- hidden_states_fp8,
455
+ (hidden_states, hidden_states_scale),
709
456
  self.w13_weight_fp8,
710
457
  gateup_output,
711
458
  masked_m,
712
459
  expected_m,
713
460
  )
714
- dispose_tensor(hidden_states_fp8[0])
461
+ dispose_tensor(hidden_states)
715
462
 
716
463
  # Act
717
464
  down_input = torch.empty(
@@ -784,99 +531,149 @@ class DeepEPMoE(EPMoE):
784
531
  def _forward_normal(dispatch_output: DeepEPNormalOutput):
785
532
  if TYPE_CHECKING:
786
533
  assert isinstance(dispatch_output, DeepEPNormalOutput)
787
- hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
788
-
789
- if isinstance(hidden_states, tuple):
790
- per_token_scale = hidden_states[1]
791
- hidden_states = hidden_states[0]
792
- else:
793
- # dynamic quant
794
- hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
795
- hidden_states
796
- )
534
+ hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
535
+ dispatch_output
536
+ )
797
537
 
798
538
  group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
799
539
  hidden_states.device
800
540
  )
541
+ if self.w13_weight.dtype != torch.int8:
542
+ # gmm1: gate_up_proj
543
+ hidden_states = torch_npu.npu_grouped_matmul(
544
+ x=[hidden_states],
545
+ weight=[self.w13_weight.permute(0, 2, 1)],
546
+ # per_token_scale=[hidden_states_scale],
547
+ split_item=2,
548
+ group_list_type=group_list_type,
549
+ group_type=0,
550
+ group_list=group_list,
551
+ output_dtype=output_dtype,
552
+ )[0]
553
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
554
+ # gmm2: down_proj
555
+ hidden_states = torch_npu.npu_grouped_matmul(
556
+ x=[hidden_states],
557
+ weight=[self.w2_weight.permute(0, 2, 1)],
558
+ split_item=2,
559
+ group_list_type=group_list_type,
560
+ group_type=0,
561
+ group_list=group_list,
562
+ output_dtype=output_dtype,
563
+ )[0]
564
+ else:
565
+ if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
566
+ hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
567
+ hidden_states
568
+ )
569
+ # gmm1: gate_up_proj
570
+ hidden_states = torch_npu.npu_grouped_matmul(
571
+ x=[hidden_states],
572
+ weight=[self.w13_weight],
573
+ scale=[self.w13_weight_scale.to(output_dtype)],
574
+ per_token_scale=[hidden_states_scale],
575
+ split_item=2,
576
+ group_list_type=group_list_type,
577
+ group_type=0,
578
+ group_list=group_list,
579
+ output_dtype=output_dtype,
580
+ )[0]
581
+
582
+ # act_fn: swiglu
583
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
584
+ hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
585
+ hidden_states
586
+ )
801
587
 
802
- # gmm1: gate_up_proj
803
- hidden_states = torch_npu.npu_grouped_matmul(
804
- x=[hidden_states],
805
- weight=[self.w13_weight],
806
- scale=[self.w13_weight_scale.to(output_dtype)],
807
- per_token_scale=[per_token_scale],
808
- split_item=2,
809
- group_list_type=group_list_type,
810
- group_type=0,
811
- group_list=group_list,
812
- output_dtype=output_dtype,
813
- )[0]
814
-
815
- # act_fn: swiglu
816
- hidden_states = torch_npu.npu_swiglu(hidden_states)
817
- hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
818
-
819
- # gmm2: down_proj
820
- hidden_states = torch_npu.npu_grouped_matmul(
821
- x=[hidden_states],
822
- weight=[self.w2_weight],
823
- scale=[self.w2_weight_scale.to(output_dtype)],
824
- per_token_scale=[swiglu_out_scale],
825
- split_item=2,
826
- group_list_type=group_list_type,
827
- group_type=0,
828
- group_list=group_list,
829
- output_dtype=output_dtype,
830
- )[0]
588
+ # gmm2: down_proj
589
+ hidden_states = torch_npu.npu_grouped_matmul(
590
+ x=[hidden_states],
591
+ weight=[self.w2_weight],
592
+ scale=[self.w2_weight_scale.to(output_dtype)],
593
+ per_token_scale=[swiglu_out_scale],
594
+ split_item=2,
595
+ group_list_type=group_list_type,
596
+ group_type=0,
597
+ group_list=group_list,
598
+ output_dtype=output_dtype,
599
+ )[0]
831
600
 
832
601
  return hidden_states
833
602
 
834
603
  def _forward_ll(dispatch_output: DeepEPLLOutput):
835
604
  if TYPE_CHECKING:
836
605
  assert isinstance(dispatch_output, DeepEPLLOutput)
837
- hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
838
-
839
- per_token_scale = hidden_states[1]
840
- hidden_states = hidden_states[0]
606
+ (
607
+ hidden_states,
608
+ hidden_states_scale,
609
+ topk_ids,
610
+ topk_weights,
611
+ group_list,
612
+ _,
613
+ ) = dispatch_output
841
614
 
842
615
  group_list = group_list.to(torch.int64)
843
616
 
844
- # gmm1: gate_up_proj
845
- hidden_states = torch_npu.npu_grouped_matmul(
846
- x=[hidden_states],
847
- weight=[self.w13_weight],
848
- split_item=2,
849
- group_list_type=group_list_type,
850
- group_type=0,
851
- group_list=group_list,
852
- output_dtype=torch.int32,
853
- )[0]
854
-
855
- # act_fn: swiglu
856
- hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
857
- x=hidden_states,
858
- weight_scale=self.w13_weight_scale.to(torch.float32),
859
- activation_scale=per_token_scale,
860
- bias=None,
861
- quant_scale=None,
862
- quant_offset=None,
863
- group_index=group_list,
864
- activate_left=True,
865
- quant_mode=1,
866
- )
617
+ if self.w13_weight.dtype != torch.int8:
618
+ # gmm1: gate_up_proj
619
+ hidden_states = torch_npu.npu_grouped_matmul(
620
+ x=[hidden_states],
621
+ weight=[self.w13_weight.permute(0, 2, 1)],
622
+ # per_token_scale=[hidden_states_scale],
623
+ split_item=2,
624
+ group_list_type=group_list_type,
625
+ group_type=0,
626
+ group_list=group_list,
627
+ output_dtype=output_dtype,
628
+ )[0]
629
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
630
+ # gmm2: down_proj
631
+ hidden_states = torch_npu.npu_grouped_matmul(
632
+ x=[hidden_states],
633
+ weight=[self.w2_weight.permute(0, 2, 1)],
634
+ split_item=2,
635
+ group_list_type=group_list_type,
636
+ group_type=0,
637
+ group_list=group_list,
638
+ output_dtype=output_dtype,
639
+ )[0]
640
+ else:
641
+ # gmm1: gate_up_proj
642
+ hidden_states = torch_npu.npu_grouped_matmul(
643
+ x=[hidden_states],
644
+ weight=[self.w13_weight],
645
+ split_item=2,
646
+ group_list_type=group_list_type,
647
+ group_type=0,
648
+ group_list=group_list,
649
+ output_dtype=torch.int32,
650
+ )[0]
651
+
652
+ # act_fn: swiglu
653
+ hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
654
+ x=hidden_states,
655
+ weight_scale=self.w13_weight_scale.to(torch.float32),
656
+ activation_scale=hidden_states_scale,
657
+ bias=None,
658
+ quant_scale=None,
659
+ quant_offset=None,
660
+ group_index=group_list,
661
+ activate_left=True,
662
+ quant_mode=1,
663
+ )
867
664
 
868
- # gmm2: down_proj
869
- hidden_states = torch_npu.npu_grouped_matmul(
870
- x=[hidden_states],
871
- weight=[self.w2_weight],
872
- scale=[self.w2_weight_scale.to(output_dtype)],
873
- per_token_scale=[swiglu_out_scale],
874
- split_item=2,
875
- group_list_type=group_list_type,
876
- group_type=0,
877
- group_list=group_list,
878
- output_dtype=output_dtype,
879
- )[0]
665
+ # gmm2: down_proj
666
+ hidden_states = torch_npu.npu_grouped_matmul(
667
+ x=[hidden_states],
668
+ weight=[self.w2_weight],
669
+ scale=[self.w2_weight_scale.to(output_dtype)],
670
+ per_token_scale=[swiglu_out_scale],
671
+ split_item=2,
672
+ group_list_type=group_list_type,
673
+ group_type=0,
674
+ group_list=group_list,
675
+ output_dtype=output_dtype,
676
+ )[0]
880
677
 
881
678
  return hidden_states
882
679
 
@@ -889,7 +686,7 @@ class DeepEPMoE(EPMoE):
889
686
 
890
687
 
891
688
  def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
892
- if get_moe_a2a_backend().is_deepep():
689
+ if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
893
690
  return DeepEPMoE
894
691
 
895
692
  # NEW: Direct FP4 detection (bypasses EP requirements)
@@ -915,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
915
712
  return FlashInferFusedMoE
916
713
  if get_moe_runner_backend().is_flashinfer_cutlass():
917
714
  return FusedMoE
918
- if get_moe_expert_parallel_world_size() > 1:
919
- return EPMoE
920
715
  return FusedMoE
921
716
 
922
717