sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 3
+     }
+ }
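
The hunk above adds one of the new fused-MoE tuning tables: each top-level key is a token-count bucket and each value is a Triton tile configuration benchmarked for that bucket. As a hedged illustration only (the helper name and lookup rule below are assumptions, not sglang's exact code), a caller could pick the nearest bucket like this:

    import json

    def load_fused_moe_config(path: str, num_tokens: int) -> dict:
        # Keys in the JSON are stringified token-count buckets.
        with open(path) as f:
            configs = {int(k): v for k, v in json.load(f).items()}
        # Pick the benchmarked bucket closest to the runtime batch size.
        best_key = min(configs, key=lambda k: abs(k - num_tokens))
        return configs[best_key]

    # Example: for 300 tokens this would return the "256" entry of a file such as
    # E=512,N=128,device_name=NVIDIA_H200.json (one of the configs listed above).
    # cfg = load_fused_moe_config("E=512,N=128,device_name=NVIDIA_H200.json", 300)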
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 3
+     }
+ }
@@ -1,3 +1,4 @@
+ # NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py

  """Fused MoE kernel."""
@@ -6,13 +7,12 @@ from __future__ import annotations

  import functools
  import os
- from typing import List, Optional
+ from typing import TYPE_CHECKING, List, Optional

  import torch
  import triton.language as tl

  from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
- from sglang.srt.layers.moe.topk import StandardTopKOutput
  from sglang.srt.utils import (
      cpu_has_amx_support,
      direct_register_custom_op,
@@ -26,6 +26,9 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_c
  from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
  from .moe_align_block_size import moe_align_block_size

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import StandardTopKOutput
+
  _is_hip = is_hip()
  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
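
The two hunks above replace the eager `StandardTopKOutput` import with a `TYPE_CHECKING` guard, so the name stays available for annotations without a runtime import. A generic, standard-library-only sketch of the pattern (not sglang code):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Skipped at runtime, so there is no import cost or circular-import risk.
        from fractions import Fraction  # stands in for StandardTopKOutput

    def half_of(x: Fraction) -> Fraction:
        # With postponed evaluation of annotations, the hint stays a string at runtime.
        return x / 2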
@@ -43,7 +43,7 @@ def get_moe_configs(
      be picked and the associated configuration chosen to invoke the kernel.
      """
      # Supported Triton versions, should be sorted from the newest to the oldest
-     supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"]
+     supported_triton_versions = ["3.4.0", "3.3.1", "3.2.0", "3.1.0"]

      # First look up if an optimized configuration is available in the configs
      # directory
@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
      token_block_id = tl.program_id(0)
      dim_block_id = tl.program_id(1)

-     token_start = token_block_id * BLOCK_M
-     token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+     offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)

-     dim_start = dim_block_id * BLOCK_DIM
-     dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+     mask_token = offs_token < token_num
+     mask_dim = offs_dim < hidden_dim

-     offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+     base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]

-     for token_index in range(token_start, token_end):
-         accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-         input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-         for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-             tmp = tl.load(
-                 input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-             )
-             accumulator += tmp
-         accumulator = accumulator * routed_scaling_factor
-         store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-         tl.store(
-             store_t_ptr,
-             accumulator.to(input_ptr.dtype.element_ty),
-             mask=offs_dim < dim_end,
+     accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+     for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+         tile = tl.load(
+             base_ptrs + i * input_stride_1,
+             mask=mask_token[:, None] & mask_dim[None, :],
+             other=0.0,
          )
+         accumulator += tile.to(tl.float32)
+     accumulator *= routed_scaling_factor
+
+     # -------- Write back --------
+     store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+     tl.store(
+         store_ptrs,
+         accumulator.to(input_ptr.dtype.element_ty),
+         mask=mask_token[:, None] & mask_dim[None, :],
+     )


  def moe_sum_reduce_triton(
@@ -772,7 +775,7 @@ def moe_sum_reduce_triton(
      BLOCK_M = 1
      BLOCK_DIM = 2048
      NUM_STAGE = 1
-     num_warps = 8
+     num_warps = 16

      grid = (
          triton.cdiv(token_num, BLOCK_M),
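
For reference, the refactored `_moe_sum_reduce_kernel` above still computes the same reduction: for each token, sum its top-k expert outputs in fp32, then rescale by `routed_scaling_factor`. A hedged PyTorch equivalent (the function name is illustrative, not an sglang API) that can be used to sanity-check the Triton kernel:

    import torch

    def moe_sum_reduce_reference(
        x: torch.Tensor,              # [token_num, topk_num, hidden_dim]
        routed_scaling_factor: float,
    ) -> torch.Tensor:                # [token_num, hidden_dim]
        # Accumulate in fp32 and rescale, mirroring the kernel's accumulator
        # dtype and its final multiply before the cast back to the input dtype.
        return (x.float().sum(dim=1) * routed_scaling_factor).to(x.dtype)

    # Example shapes:
    # x = torch.randn(32, 8, 4096, dtype=torch.bfloat16)
    # out = moe_sum_reduce_reference(x, routed_scaling_factor=1.0)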
@@ -23,8 +23,14 @@ from sglang.srt.layers.moe import (
      get_moe_runner_backend,
      should_use_flashinfer_trtllm_moe,
  )
+ from sglang.srt.layers.moe.token_dispatcher.standard import (
+     CombineInput,
+     StandardDispatcher,
+     StandardDispatchOutput,
+ )
  from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
@@ -68,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
  logger = logging.getLogger(__name__)


- def _is_fp4_quantization_enabled():
-     """Check if ModelOpt FP4 quantization is enabled."""
-     try:
-         # Use the same simple check that works for class selection
-         quantization = global_server_args_dict.get("quantization")
-         return quantization == "modelopt_fp4"
-     except:
-         return False
-
-
  def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
      # Guess tokens per expert assuming perfect expert distribution first.
      num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -152,16 +148,6 @@ class FusedMoE(torch.nn.Module):
          self.expert_map_cpu = None
          self.expert_map_gpu = None

-         self.moe_runner_config = MoeRunnerConfig(
-             activation=activation,
-             apply_router_weight_on_input=apply_router_weight_on_input,
-             inplace=inplace,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
-             gemm1_alpha=gemm1_alpha,
-             gemm1_clamp_limit=gemm1_clamp_limit,
-         )
-
          enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

          if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -196,13 +182,6 @@ class FusedMoE(torch.nn.Module):
          self.use_presharded_weights = use_presharded_weights

          self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-         if quant_config is None:
-             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                 self.use_triton_kernels
-             )
-         else:
-             self.quant_method = quant_config.get_quant_method(self, prefix)
-             assert self.quant_method is not None

          self.quant_config = quant_config
          self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -213,12 +192,40 @@ class FusedMoE(torch.nn.Module):
              and self.use_flashinfer_mxfp4_moe
          ):
              hidden_size = round_up(hidden_size, 256)
+         self.hidden_size = hidden_size
+
+         self.moe_runner_config = MoeRunnerConfig(
+             num_experts=num_experts,
+             num_local_experts=self.num_local_experts,
+             hidden_size=hidden_size,
+             intermediate_size_per_partition=self.intermediate_size_per_partition,
+             layer_id=layer_id,
+             top_k=top_k,
+             num_fused_shared_experts=num_fused_shared_experts,
+             params_dtype=params_dtype,
+             activation=activation,
+             apply_router_weight_on_input=apply_router_weight_on_input,
+             inplace=inplace,
+             no_combine=no_combine,
+             routed_scaling_factor=routed_scaling_factor,
+             gemm1_alpha=gemm1_alpha,
+             gemm1_clamp_limit=gemm1_clamp_limit,
+         )
+
+         if quant_config is None:
+             self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
+                 self.use_triton_kernels
+             )
+         else:
+             self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
+                 self, prefix
+             )
+             assert self.quant_method is not None
+
          self.quant_method.create_weights(
              layer=self,
              num_experts=self.num_local_experts,
              hidden_size=hidden_size,
-             # FIXME: figure out which intermediate_size to use
-             intermediate_size=self.intermediate_size_per_partition,
              intermediate_size_per_partition=self.intermediate_size_per_partition,
              params_dtype=params_dtype,
              weight_loader=(
@@ -229,6 +236,9 @@ class FusedMoE(torch.nn.Module):
              with_bias=with_bias,
          )

+         self.quant_method.create_moe_runner(self, self.moe_runner_config)
+         self.dispatcher = StandardDispatcher()
+
      def _load_per_tensor_weight_scale(
          self,
          shard_id: str,
@@ -522,10 +532,12 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          expert_id: int,
      ) -> None:
+         # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+         if not getattr(param, "_sglang_require_global_experts", False):
+             expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+             if expert_id == -1:
+                 return

-         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-         if expert_id == -1:
-             return

          self._weight_loader_impl(
              param=param,
@@ -594,8 +606,10 @@ class FusedMoE(torch.nn.Module):
          loaded_weight = loaded_weight.to(param.data.device)

          if (
-             "compressed" in self.quant_method.__class__.__name__.lower()
-             or "w4afp8" in self.quant_config.get_name()
+             (
+                 "compressed" in self.quant_method.__class__.__name__.lower()
+                 or "w4afp8" in self.quant_config.get_name()
+             )
              and (param.data[expert_id] != 1).any()
              and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
          ):
@@ -811,16 +825,17 @@ class FusedMoE(torch.nn.Module):
          elif TopKOutputChecker.format_is_triton_kernel(topk_output):
              raise NotImplementedError()

-         # Matrix multiply.
-         with use_symmetric_memory(get_tp_group()) as sm:
+         dispatch_output = self.dispatcher.dispatch(
+             hidden_states=hidden_states, topk_output=topk_output
+         )

-             final_hidden_states = self.quant_method.apply(
-                 layer=self,
-                 x=hidden_states,
-                 topk_output=topk_output,
-                 moe_runner_config=self.moe_runner_config,
-             )
-             sm.tag(final_hidden_states)
+         # TODO: consider using symmetric memory
+         combine_input = self.quant_method.apply(
+             layer=self,
+             dispatch_output=dispatch_output,
+         )
+
+         final_hidden_states = self.dispatcher.combine(combine_input)

          final_hidden_states = final_hidden_states[
              ..., :origin_hidden_states_dim
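
The hunk above replaces the direct `quant_method.apply(x, topk_output, ...)` call with a dispatch/apply/combine pipeline through `StandardDispatcher`. A schematic toy version of that control flow (class names and bodies below are stand-ins, not the real sglang interfaces):

    class ToyDispatcher:
        def dispatch(self, hidden_states, topk_output):
            # Real dispatchers may reorder or pad tokens; this one just bundles inputs.
            return (hidden_states, topk_output)

        def combine(self, combine_input):
            # Real dispatchers undo the dispatch-time reordering here.
            return combine_input


    class ToyQuantMethod:
        def apply(self, layer, dispatch_output):
            hidden_states, _topk_output = dispatch_output
            # Real quant methods run the fused expert GEMMs here.
            return hidden_states


    def toy_moe_forward(hidden_states, topk_output):
        dispatcher, quant_method = ToyDispatcher(), ToyQuantMethod()
        dispatch_output = dispatcher.dispatch(hidden_states, topk_output)
        combine_input = quant_method.apply(layer=None, dispatch_output=dispatch_output)
        return dispatcher.combine(combine_input)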
@@ -953,9 +968,9 @@ class FlashInferFusedMoE(FusedMoE):
          # Matrix multiply.
          final_hidden_states = self.quant_method.apply_with_router_logits(
              layer=self,
-             x=hidden_states,
-             topk_output=topk_output,
-             moe_runner_config=self.moe_runner_config,
+             dispatch_output=StandardDispatchOutput(
+                 hidden_states=hidden_states, topk_output=topk_output
+             ),
          )

          if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1055,16 +1070,3 @@ class FlashInferFP4MoE(FusedMoE):
          )[0]

          return result
-
-
- def get_fused_moe_impl_class():
-     """Factory function to get the appropriate FusedMoE implementation class."""
-     if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
-         # Use FP4 variant when FP4 quantization is enabled
-         return FlashInferFP4MoE
-     elif should_use_flashinfer_trtllm_moe():
-         # Use regular FlashInfer variant for non-FP4 FlashInfer cases
-         return FlashInferFusedMoE
-     else:
-         # Default case
-         return FusedMoE
@@ -1,3 +1,4 @@
  from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.runner import MoeRunner

- __all__ = ["MoeRunnerConfig"]
+ __all__ = ["MoeRunnerConfig", "MoeRunner"]
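
With the `__all__` update above, both symbols are importable from the package root. A minimal usage sketch (constructor arguments are not shown here; they are defined in moe_runner/base.py and moe_runner/runner.py, which this diff also touches):

    from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig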