sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert (
-        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-        and w1_scale.shape[2] == w1_q.shape[1] * 4
-    ), "W1 scale shape mismatch"
-    assert (
-        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-        and w2_scale.shape[2] == w2_q.shape[1] * 4
-    ), "W2 scale shape mismatch"
 
     assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
-    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
     num_experts = w1_q.size(0)
     m = a.size(0)
@@ -155,8 +147,8 @@ def cutlass_w4a8_moe(
         k,
     )
 
-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
 
     cutlass_w4a8_moe_mm(
         c1,
@@ -174,7 +166,7 @@ def cutlass_w4a8_moe(
         topk,
     )
 
-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
     silu_and_mul(c1, intermediate)
 
     intermediate_q = torch.empty(
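Note: the two hunks above switch cutlass_w4a8_moe's intermediate buffers from torch.half to torch.bfloat16. As a minimal plain-PyTorch sketch of what the fused silu_and_mul step then computes on that bfloat16 buffer (assuming the usual convention of SiLU on the first n columns gated by the last n; this stands in for the CUDA kernel, it is not the kernel itself):

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes; c1 stands in for the (m * topk, n * 2) gate_up output.
    m_topk, n = 8, 1024
    c1 = torch.randn(m_topk, 2 * n, dtype=torch.bfloat16)

    # silu_and_mul equivalent: SiLU on the first half, times the second half.
    intermediate = (F.silu(c1[:, :n].float()) * c1[:, n:].float()).to(torch.bfloat16)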
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = -1
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
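For orientation, a hedged usage sketch of the new zero_experts_compute_triton helper added above (shapes, the CUDA device, and the import path are assumptions; per the file list the function lands in sglang/srt/layers/moe/ep_moe/kernels.py alongside moe_ep_deepgemm_preprocess):

    import torch

    from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton

    # Hypothetical routing: 4 tokens, top_k = 2, hidden_dim = 512, 8 real experts;
    # indices >= num_experts denote "zero" (identity) experts.
    expert_indices = torch.randint(0, 10, (4, 2), device="cuda")
    expert_scales = torch.rand(4, 2, device="cuda")
    hidden_states = torch.randn(4, 512, device="cuda")

    out = zero_experts_compute_triton(
        expert_indices,
        expert_scales,
        num_experts=8,
        zero_expert_type="identity",
        hidden_states=hidden_states,
    )
    # `out` holds hidden_states scaled by the weight routed to identity experts;
    # expert_indices / expert_scales are edited in place so the identity slots are
    # masked out (-1 index, 0.0 scale) before the regular expert computation runs.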
@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -114,9 +113,6 @@ class EPMoE(FusedMoE):
             with_bias=with_bias,
         )
 
-        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
-
         self.intermediate_size = intermediate_size
 
         if isinstance(quant_config, Fp8Config):
@@ -232,7 +228,7 @@ class EPMoE(FusedMoE):
             (
                 _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                     gateup_input_scale
                 )
             ),
@@ -248,7 +244,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -290,9 +285,7 @@ class EPMoE(FusedMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -304,7 +297,6 @@ class EPMoE(FusedMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -461,7 +453,7 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
-            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
@@ -667,7 +659,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])
 
@@ -708,9 +699,7 @@ class DeepEPMoE(EPMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -722,64 +711,130 @@ class DeepEPMoE(EPMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
 
         return down_output
 
     def forward_npu(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        if TYPE_CHECKING:
-            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
 
+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
+        group_list_type = 1
 
-        pertoken_scale = hidden_states[1]
-        hidden_states = hidden_states[0]
+        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPNormalOutput)
+            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
+
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
+            else:
+                # dynamic quant
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )
 
-        group_list_type = 1
-        seg_indptr = seg_indptr.to(torch.int64)
+            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
+                hidden_states.device
+            )
 
-        import torch_npu
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)],
+                per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            return hidden_states
 
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
-
-        # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
-
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
+        def _forward_ll(dispatch_output: DeepEPLLOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPLLOutput)
+            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
+
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
+
+            group_list = group_list.to(torch.int64)
+
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )
 
-        return hidden_states
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            return hidden_states
+
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            return _forward_normal(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            return _forward_ll(dispatch_output)
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
 
 
 def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
@@ -8,16 +8,18 @@ from torch.nn import functional as F
 
 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
 from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
 def fused_moe_forward_native(
     layer: torch.nn.Module,
-    x: torch.Tensor,
-    topk_output: StandardTopKOutput,
-    moe_runner_config: MoeRunnerConfig,
+    dispatch_output: StandardDispatchOutput,
 ) -> torch.Tensor:
 
+    x, topk_output = dispatch_output
+    moe_runner_config = layer.moe_runner_config
+
     if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
 
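The hunk above changes fused_moe_forward_native to take the layer plus a single StandardDispatchOutput and to read the runner config off the layer. A minimal stand-in sketch of the new calling convention (the NamedTuple and layer below are illustrative stand-ins, not the sglang types):

    from typing import Any, NamedTuple

    import torch


    class DispatchOutputStub(NamedTuple):  # stand-in for StandardDispatchOutput
        hidden_states: torch.Tensor
        topk_output: Any


    class LayerStub:  # stand-in: the runner config now lives on the layer
        moe_runner_config = None


    def forward_native_new_style(layer, dispatch_output):
        x, topk_output = dispatch_output             # was: separate x / topk_output args
        moe_runner_config = layer.moe_runner_config  # was: an explicit third argument
        return x                                     # placeholder body


    out = forward_native_new_style(
        LayerStub(), DispatchOutputStub(torch.randn(2, 4), None)
    )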
@@ -1,16 +1,18 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_experts,
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
     get_config_file_name,
-    moe_align_block_size,
     try_get_optimal_moe_config,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
 )
+from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
+    moe_align_block_size,
+)
 
 _config: Optional[Dict[str, Any]] = None
 
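Because the package __init__ keeps re-exporting the same names, imports that go through the package continue to work unchanged; only direct deep imports of the relocated helpers need the new module paths, which are the ones the hunk switches to:

    # Unchanged: package-level imports still resolve via the __init__ re-exports.
    from sglang.srt.layers.moe.fused_moe_triton import fused_experts, moe_align_block_size

    # New homes for direct (deep) imports of the relocated helpers:
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
        get_config_file_name,
        try_get_optimal_moe_config,
    )
    from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
        moe_align_block_size,
    )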
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
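The new file above is one of several tuned fused-MoE Triton configuration tables added in this release: keys are candidate batch sizes (M) and each entry carries Triton tile and launch parameters. A hedged sketch of how such a table can be consumed (the nearest-key lookup is an assumption about the selection policy, not a quote of sglang's try_get_optimal_moe_config):

    import json


    def pick_moe_config(path: str, m: int) -> dict:
        # Load the tuned table and return the entry whose batch-size key is closest to m.
        with open(path) as f:
            table = json.load(f)
        best_key = min(table, key=lambda key: abs(int(key) - m))
        return table[best_key]


    # Example: a batch of 300 tokens would resolve to the "256" entry of the table above.
    # cfg = pick_moe_config("path/to/E=...,N=...,device_name=....json", 300)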