sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert (
-        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-        and w1_scale.shape[2] == w1_q.shape[1] * 4
-    ), "W1 scale shape mismatch"
-    assert (
-        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-        and w2_scale.shape[2] == w2_q.shape[1] * 4
-    ), "W2 scale shape mismatch"

     assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
-    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
     num_experts = w1_q.size(0)
     m = a.size(0)
@@ -155,8 +147,8 @@ def cutlass_w4a8_moe(
         k,
     )

-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

     cutlass_w4a8_moe_mm(
         c1,
@@ -174,7 +166,7 @@ def cutlass_w4a8_moe(
         topk,
     )

-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
     silu_and_mul(c1, intermediate)

     intermediate_q = torch.empty(
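
The hunks above switch the intermediate buffers of cutlass_w4a8_moe from torch.half to torch.bfloat16: c1 holds the gate and up projections side by side (width n * 2), silu_and_mul reduces it to the width-n intermediate, and c2 receives the final down projection. A reference sketch of the silu_and_mul step, assuming the conventional gate/up split used by SiluAndMul (not taken from this diff):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
        # Reference semantics for a [tokens, 2*n] gate_up buffer, assuming the
        # first half is the gate and the second half is the up projection.
        n = gate_up.shape[-1] // 2
        return F.silu(gate_up[..., :n]) * gate_up[..., n:]

    c1 = torch.randn(8, 2 * 1024, dtype=torch.bfloat16)
    intermediate = silu_and_mul_reference(c1)  # [8, 1024], stays in bfloat16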
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = -1
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
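
The new zero_experts_compute_triton helper (added in sglang/srt/layers/moe/ep_moe/kernels.py per the file list) handles routing slots whose expert id falls outside the real expert range. For the "identity" type it keeps only the scales of those out-of-range slots, masks them out of the normal MoE path by rewriting expert_indices/expert_scales in place, and launches compute_identity_kernel, which accumulates each token's own hidden state weighted by the remaining scales. A hedged usage sketch; the shapes and the import path are assumptions for illustration, not taken from the diff:

    import torch
    from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton

    # Hypothetical inputs: 4 tokens, top-2 routing, 8 real experts; ids >= 8
    # denote "zero experts" that contribute a pass-through (identity) term.
    hidden_states = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
    expert_indices = torch.randint(0, 10, (4, 2), device="cuda")
    expert_scales = torch.rand(4, 2, device="cuda", dtype=torch.float32)

    identity_out = zero_experts_compute_triton(
        expert_indices=expert_indices,
        expert_scales=expert_scales,
        num_experts=8,
        zero_expert_type="identity",
        hidden_states=hidden_states,
    )
    # identity_out matches hidden_states' shape and is presumably added to the
    # regular MoE output; expert_indices/expert_scales are modified in place.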
@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -114,9 +113,6 @@ class EPMoE(FusedMoE):
             with_bias=with_bias,
         )

-        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
-
         self.intermediate_size = intermediate_size

         if isinstance(quant_config, Fp8Config):
@@ -232,7 +228,7 @@ class EPMoE(FusedMoE):
             (
                 _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                     gateup_input_scale
                 )
             ),
@@ -289,9 +285,7 @@ class EPMoE(FusedMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -459,7 +453,7 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
-            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
@@ -723,57 +717,124 @@ class DeepEPMoE(EPMoE):

     def forward_npu(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        if TYPE_CHECKING:
-            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
+        group_list_type = 1

-        pertoken_scale = hidden_states[1]
-        hidden_states = hidden_states[0]
+        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPNormalOutput)
+            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
+
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
+            else:
+                # dynamic quant
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )

-        group_list_type = 1
-        seg_indptr = seg_indptr.to(torch.int64)
+            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
+                hidden_states.device
+            )

-        import torch_npu
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)],
+                per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            return hidden_states

-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
-
-        # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
-
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
+        def _forward_ll(dispatch_output: DeepEPLLOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPLLOutput)
+            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
+
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
+
+            group_list = group_list.to(torch.int64)
+
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )

-        return hidden_states
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            return hidden_states
+
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            return _forward_normal(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            return _forward_ll(dispatch_output)
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")


 def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
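
forward_npu now accepts both DeepEP dispatch formats and branches on DispatchOutputChecker: _forward_normal consumes per-expert token counts (num_recv_tokens_per_expert), while _forward_ll keeps the low-latency path with its fused dequant-SwiGLU-quant step. Both express the expert MLP as two npu_grouped_matmul calls around a SwiGLU. As a rough reference for what a grouped matmul over a per-expert count list computes, here is a plain PyTorch sketch; the weight layout and the reading of group_list_type=1 as "counts per expert" are assumptions, and quantization scales are ignored:

    import torch

    def grouped_matmul_reference(x, weights, group_counts):
        # x: [sum(group_counts), K], rows already sorted by expert.
        # weights: [num_experts, K, N]; group_counts[i]: rows owned by expert i.
        outs = []
        for expert_id, chunk in enumerate(torch.split(x, group_counts, dim=0)):
            outs.append(chunk @ weights[expert_id])  # [group_counts[i], N]
        return torch.cat(outs, dim=0)

    # e.g. 3 experts receiving 2, 0 and 4 tokens respectively
    x = torch.randn(6, 8)
    w = torch.randn(3, 8, 16)
    y = grouped_matmul_reference(x, w, [2, 0, 4])
    print(y.shape)  # torch.Size([6, 16])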
@@ -8,16 +8,18 @@ from torch.nn import functional as F

 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
 from sglang.srt.layers.moe.topk import StandardTopKOutput


 def fused_moe_forward_native(
     layer: torch.nn.Module,
-    x: torch.Tensor,
-    topk_output: StandardTopKOutput,
-    moe_runner_config: MoeRunnerConfig,
+    dispatch_output: StandardDispatchOutput,
 ) -> torch.Tensor:

+    x, topk_output = dispatch_output
+    moe_runner_config = layer.moe_runner_config
+
     if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()

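
fused_moe_forward_native now takes a single StandardDispatchOutput instead of separate x / topk_output / moe_runner_config arguments, and the runner config is read from the layer itself. A sketch of how a call site would adapt, assuming StandardDispatchOutput is the two-field tuple implied by the unpacking x, topk_output = dispatch_output:

    from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
    from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput

    def forward_with_dispatch(layer, x, topk_output):
        # 0.5.1.post3: fused_moe_forward_native(layer, x, topk_output, moe_runner_config)
        # 0.5.2: wrap the routed activations and top-k result into the dispatch
        # output; moe_runner_config now comes from layer.moe_runner_config.
        return fused_moe_forward_native(layer, StandardDispatchOutput(x, topk_output))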
@@ -1,16 +1,18 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional

-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_experts,
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
     get_config_file_name,
-    moe_align_block_size,
     try_get_optimal_moe_config,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
 )
+from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
+    moe_align_block_size,
+)

 _config: Optional[Dict[str, Any]] = None

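
The package-level re-exports in this __init__.py keep working, but the underlying modules move: the tuning-config helpers now live in fused_moe_triton_config and moe_align_block_size gets its own module (new files 94 and 97 in the list above). Imports map roughly as follows; the deep paths are taken from this hunk, and the package-level form is unchanged:

    # Package-level imports keep resolving, since __init__.py still binds these names.
    from sglang.srt.layers.moe.fused_moe_triton import (
        fused_experts,
        get_config_file_name,
        moe_align_block_size,
        try_get_optimal_moe_config,
    )

    # Deep imports follow the files to their new modules.
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
        get_config_file_name,
        try_get_optimal_moe_config,
    )
    from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
        moe_align_block_size,
    )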
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    }
+}
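
The two new JSON files above are Triton MoE tuning tables: each top-level key is a token-batch bucket and maps to the tile sizes (BLOCK_SIZE_M/N/K), group size, warp count, and pipeline stages used by the fused MoE kernel. get_config_file_name and try_get_optimal_moe_config (re-exported in the __init__.py hunk earlier) choose the file by expert count, intermediate size, device name, and dtype, then pick a bucket for the current batch. A simplified stand-in for that lookup, assuming nearest-bucket selection:

    import json

    def pick_moe_config(path: str, num_tokens: int) -> dict:
        # Load a fused-MoE tuning table and return the entry whose batch bucket
        # is numerically closest to num_tokens.
        with open(path) as f:
            table = json.load(f)
        best_key = min(table.keys(), key=lambda k: abs(int(k) - num_tokens))
        return table[best_key]

    # For a 300-token batch this returns the "256" entry of the table
    # (file name taken from the list above).
    cfg = pick_moe_config("E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json", 300)
    print(cfg["BLOCK_SIZE_M"], cfg["num_warps"], cfg["num_stages"])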