sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 256,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     }
+ }
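
Each of these new config files maps a token count (the string key, corresponding to the M dimension of the expert GEMM) to Triton launch parameters pre-tuned for the GPU named in the file path. As a rough illustration of how such a table is consumed, here is a minimal sketch assuming the usual fused-MoE convention of picking the entry whose key is closest to the live token count; `select_moe_config` is an illustrative helper, not sglang's actual loader.

```python
import json

def select_moe_config(config_path: str, num_tokens: int) -> dict:
    # Keys in the JSON are token counts stored as strings ("1", "2", ..., "4096").
    with open(config_path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Pick the tuned entry whose key is numerically closest to the batch at hand.
    best_key = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[best_key]

# e.g. for 20 tokens this would return the "16"/"24" entry above:
# {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#  "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}
```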
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "16": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "256": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "512": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 3
+     }
+ }
@@ -2,17 +2,20 @@
  
  """Fused MoE kernel."""
  
+ from __future__ import annotations
+
  import functools
  import json
  import logging
  import os
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
  
  import torch
  import triton
  import triton.language as tl
  
- from sglang.srt.layers.moe.topk import TopKOutput
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+ from sglang.srt.layers.moe.topk import StandardTopKOutput
  from sglang.srt.layers.quantization.fp8_kernel import (
      per_token_group_quant_fp8,
      scaled_fp8_quant,
@@ -46,13 +49,15 @@ if _is_cuda:
  elif _is_cpu and _is_cpu_amx_available:
      pass
  elif _is_hip:
-     from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+     from sgl_kernel import gelu_and_mul, silu_and_mul
  
      if _use_aiter:
          try:
              from aiter import moe_sum
          except ImportError:
              raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+     else:
+         from vllm import _custom_ops as vllm_ops
  
  
  if _is_cuda or _is_hip:
@@ -1025,8 +1030,8 @@ def inplace_fused_experts(
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
      routed_scaling_factor: Optional[float] = None,
-     activation_alpha: Optional[float] = None,
-     swiglu_limit: Optional[float] = None,
+     gemm1_alpha: Optional[float] = None,
+     gemm1_limit: Optional[float] = None,
  ) -> None:
      fused_experts_impl(
          hidden_states,
@@ -1053,8 +1058,8 @@
          block_shape,
          False,
          routed_scaling_factor,
-         activation_alpha,
-         swiglu_limit,
+         gemm1_alpha,
+         gemm1_limit,
      )
  
  
@@ -1081,8 +1086,8 @@ def inplace_fused_experts_fake(
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
      routed_scaling_factor: Optional[float] = None,
-     activation_alpha: Optional[float] = None,
-     swiglu_limit: Optional[float] = None,
+     gemm1_alpha: Optional[float] = None,
+     gemm1_limit: Optional[float] = None,
  ) -> None:
      pass
  
@@ -1119,8 +1124,8 @@ def outplace_fused_experts(
      block_shape: Optional[List[int]] = None,
      no_combine: bool = False,
      routed_scaling_factor: Optional[float] = None,
-     activation_alpha: Optional[float] = None,
-     swiglu_limit: Optional[float] = None,
+     gemm1_alpha: Optional[float] = None,
+     gemm1_limit: Optional[float] = None,
  ) -> torch.Tensor:
      return fused_experts_impl(
          hidden_states,
@@ -1147,8 +1152,8 @@
          block_shape,
          no_combine=no_combine,
          routed_scaling_factor=routed_scaling_factor,
-         activation_alpha=activation_alpha,
-         swiglu_limit=swiglu_limit,
+         gemm1_alpha=gemm1_alpha,
+         gemm1_limit=gemm1_limit,
      )
  
  
@@ -1176,8 +1181,8 @@ def outplace_fused_experts_fake(
      block_shape: Optional[List[int]] = None,
      no_combine: bool = False,
      routed_scaling_factor: Optional[float] = None,
-     activation_alpha: Optional[float] = None,
-     swiglu_limit: Optional[float] = None,
+     gemm1_alpha: Optional[float] = None,
+     gemm1_limit: Optional[float] = None,
  ) -> torch.Tensor:
      return torch.empty_like(hidden_states)
  
@@ -1194,12 +1199,10 @@ def fused_experts(
      hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
-     topk_output: TopKOutput,
+     topk_output: StandardTopKOutput,
+     moe_runner_config: MoeRunnerConfig,
      b1: Optional[torch.Tensor] = None,
      b2: Optional[torch.Tensor] = None,
-     inplace: bool = False,
-     activation: str = "silu",
-     apply_router_weight_on_input: bool = False,
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
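
With this change, `fused_experts` no longer takes `inplace`, `activation`, and `apply_router_weight_on_input` as flat keyword arguments; those knobs now travel inside the `moe_runner_config` object, whose fields are read in the hunks that follow (`moe_runner_config.inplace`, `.no_combine`, `.activation`, `.apply_router_weight_on_input`, `.routed_scaling_factor`, `.gemm1_alpha`, `.gemm1_clamp_limit`). Below is a minimal sketch of such a config, with field names taken from those accesses and defaults assumed from the removed parameters; the real `MoeRunnerConfig` in `sglang/srt/layers/moe/moe_runner/base.py` may carry additional fields and different defaults.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeRunnerConfig:
    # Field names mirror the attributes this diff reads from moe_runner_config;
    # defaults are assumed from the keyword arguments the diff removes.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = False
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    gemm1_alpha: Optional[float] = None
    gemm1_clamp_limit: Optional[float] = None
```

Callers that previously passed `inplace=True` or `activation="gelu"` directly to `fused_experts`/`fused_moe` now set the corresponding fields on this config object instead.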
@@ -1212,14 +1215,10 @@
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
-     no_combine: bool = False,
-     routed_scaling_factor: Optional[float] = None,
-     activation_alpha: Optional[float] = None,
-     swiglu_limit: Optional[float] = None,
  ):
      topk_weights, topk_ids, _ = topk_output
-     if inplace:
-         assert not no_combine, "no combine + inplace makes no sense"
+     if moe_runner_config.inplace:
+         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
          torch.ops.sglang.inplace_fused_experts(
              hidden_states,
              w1,
@@ -1228,8 +1227,8 @@
              topk_ids,
              b1,
              b2,
-             activation,
-             apply_router_weight_on_input,
+             moe_runner_config.activation,
+             moe_runner_config.apply_router_weight_on_input,
              use_fp8_w8a8,
              use_int8_w8a8,
              use_int8_w8a16,
@@ -1242,9 +1241,9 @@
              a1_scale,
              a2_scale,
              block_shape,
-             routed_scaling_factor,
-             activation_alpha,
-             swiglu_limit,
+             moe_runner_config.routed_scaling_factor,
+             moe_runner_config.gemm1_alpha,
+             moe_runner_config.gemm1_clamp_limit,
          )
          return hidden_states
      else:
@@ -1256,8 +1255,8 @@
              topk_ids,
              b1,
              b2,
-             activation,
-             apply_router_weight_on_input,
+             moe_runner_config.activation,
+             moe_runner_config.apply_router_weight_on_input,
              use_fp8_w8a8,
              use_int8_w8a8,
              use_int8_w8a16,
@@ -1270,10 +1269,10 @@
              a1_scale,
              a2_scale,
              block_shape,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
-             activation_alpha=activation_alpha,
-             swiglu_limit=swiglu_limit,
+             no_combine=moe_runner_config.no_combine,
+             routed_scaling_factor=moe_runner_config.routed_scaling_factor,
+             gemm1_alpha=moe_runner_config.gemm1_alpha,
+             gemm1_limit=moe_runner_config.gemm1_clamp_limit,
          )
  
  
@@ -1370,11 +1369,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
  
  
  @torch.compile
- def swiglu_with_alpha_and_limit(x, alpha, limit):
+ def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
      gate, up = x[..., ::2], x[..., 1::2]
-     gate = gate.clamp(min=None, max=limit)
-     up = up.clamp(min=-limit, max=limit)
-     return gate * torch.sigmoid(gate * alpha) * (up + 1)
+     gate = gate.clamp(min=None, max=gemm1_limit)
+     up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
+     return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
  
  
  def fused_experts_impl(
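
For reference, the renamed `swiglu_with_alpha_and_limit` treats the last dimension as interleaved gate/up channels, clamps both against `gemm1_limit`, and gates with a sigmoid scaled by `gemm1_alpha`. A standalone reproduction of the function shown above (without the `@torch.compile` decorator); the alpha and limit values in the example call are illustrative, not sglang defaults:

```python
import torch

def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
    # Even channels are the gate, odd channels are the up projection.
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(min=None, max=gemm1_limit)
    up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
    return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)

x = torch.randn(4, 8)  # 8 interleaved channels -> 4 output channels
y = swiglu_with_alpha_and_limit(x, gemm1_alpha=1.702, gemm1_limit=7.0)
print(y.shape)  # torch.Size([4, 4])
```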
@@ -1402,8 +1401,8 @@ def fused_experts_impl(
      block_shape: Optional[List[int]] = None,
      no_combine: bool = False,
      routed_scaling_factor: Optional[float] = None,
-     activation_alpha: Optional[float] = None,
-     swiglu_limit: Optional[float] = None,
+     gemm1_alpha: Optional[float] = None,
+     gemm1_limit: Optional[float] = None,
  ):
      padded_size = padding_size
      if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1533,25 +1532,23 @@
              block_shape=block_shape,
          )
          if activation == "silu":
-             if activation_alpha is not None:
-                 assert swiglu_limit is not None
+             if gemm1_alpha is not None:
+                 assert gemm1_limit is not None
                  intermediate_cache2 = swiglu_with_alpha_and_limit(
                      intermediate_cache1.view(-1, N),
-                     activation_alpha,
-                     swiglu_limit,
+                     gemm1_alpha,
+                     gemm1_limit,
                  )
-             elif _is_cuda:
+             elif _is_cuda or _is_hip:
                  silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
              else:
                  vllm_ops.silu_and_mul(
                      intermediate_cache2, intermediate_cache1.view(-1, N)
                  )
          elif activation == "gelu":
-             assert (
-                 activation_alpha is None
-             ), "activation_alpha is not supported for gelu"
-             assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
-             if _is_cuda:
+             assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
+             assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
+             if _is_cuda or _is_hip:
                  gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
              else:
                  vllm_ops.gelu_and_mul(
@@ -1624,10 +1621,19 @@
                  out_hidden_states[begin_chunk_idx:end_chunk_idx],
              )
          else:
-             vllm_ops.moe_sum(
-                 intermediate_cache3.view(*intermediate_cache3.shape),
-                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
-             )
+             # According to micro benchmark results, torch.compile can get better performance for small token.
+             if tokens_in_chunk <= 32:
+                 moe_sum_reduce_torch_compile(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                     routed_scaling_factor,
+                 )
+             else:
+                 moe_sum_reduce_triton(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                     routed_scaling_factor,
+                 )
      else:
          vllm_ops.moe_sum(
              intermediate_cache3.view(*intermediate_cache3.shape),
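
The new reduction path above chooses between a `torch.compile`d sum and a Triton kernel based on the chunk's token count. Here is a self-contained sketch of that dispatch heuristic, assuming the reduction sums the per-expert partial outputs over the top-k dimension and scales by `routed_scaling_factor`; the plain-PyTorch `sum_reduce` below stands in for the Triton kernel and none of these helpers are sglang's actual implementations.

```python
import torch

@torch.compile
def sum_reduce_compiled(x, out, routed_scaling_factor):
    # Small-batch path: let torch.compile fuse the sum and the scaling.
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)

def sum_reduce(x, out, routed_scaling_factor):
    # Stand-in for the Triton kernel used on larger batches.
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)

def moe_combine(intermediate, out, routed_scaling_factor, small_batch_threshold=32):
    # intermediate: [tokens, topk, hidden]; out: [tokens, hidden]
    if intermediate.shape[0] <= small_batch_threshold:
        sum_reduce_compiled(intermediate, out, routed_scaling_factor)
    else:
        sum_reduce(intermediate, out, routed_scaling_factor)

tokens, topk, hidden = 16, 8, 128
cache3 = torch.randn(tokens, topk, hidden)
out = torch.empty(tokens, hidden)
moe_combine(cache3, out, routed_scaling_factor=1.0)
```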
@@ -1641,12 +1647,10 @@ def fused_moe(
      hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
-     topk_output: TopKOutput,
+     topk_output: StandardTopKOutput,
+     moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
      b1: Optional[torch.Tensor] = None,
      b2: Optional[torch.Tensor] = None,
-     inplace: bool = False,
-     activation: str = "silu",
-     apply_router_weight_on_input: bool = False,
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
@@ -1659,10 +1663,6 @@
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
-     no_combine: bool = False,
-     routed_scaling_factor: Optional[float] = None,
-     activation_alpha: Optional[float] = None,
-     swiglu_limit: Optional[float] = None,
  ) -> torch.Tensor:
      """
      This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1672,11 +1672,10 @@
      - hidden_states (torch.Tensor): The input tensor to the MoE layer.
      - w1 (torch.Tensor): The first set of expert weights.
      - w2 (torch.Tensor): The second set of expert weights.
-     - topk_output (TopKOutput): The top-k output of the experts.
+     - topk_output (StandardTopKOutput): The top-k output of the experts.
+     - moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
      - b1 (Optional[torch.Tensor]): Optional bias for w1.
      - b2 (Optional[torch.Tensor]): Optional bias for w2.
-     - inplace (bool): If True, perform the operation in-place.
-         Defaults to False.
      - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
          products for w1 and w2. Defaults to False.
      - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1696,9 +1695,9 @@
          a2.
      - block_shape: (Optional[List[int]]): Optional block size for block-wise
          quantization.
-     - activation_alpha (Optional[float]): Optional alpha for the activation
+     - gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
          function.
-     - swiglu_limit (Optional[float]): Optional limit for the swiglu activation
+     - gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
          function.
  
      Returns:
@@ -1710,11 +1709,9 @@
          w1,
          w2,
          topk_output,
+         moe_runner_config=moe_runner_config,
          b1=b1,
          b2=b2,
-         inplace=inplace,
-         activation=activation,
-         apply_router_weight_on_input=apply_router_weight_on_input,
          use_fp8_w8a8=use_fp8_w8a8,
          use_int8_w8a8=use_int8_w8a8,
          use_int8_w8a16=use_int8_w8a16,
@@ -1727,8 +1724,4 @@
          a1_scale=a1_scale,
          a2_scale=a2_scale,
          block_shape=block_shape,
-         no_combine=no_combine,
-         routed_scaling_factor=routed_scaling_factor,
-         activation_alpha=activation_alpha,
-         swiglu_limit=swiglu_limit,
      )