sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     }
+ }
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
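
The two hunks above are new tuned-kernel configuration files (items 50-55 in the file list add six such files; no per-hunk file headers survived extraction, so which two are shown is not indicated). Each top-level key is a token count M for the fused MoE GEMM, and each value is the Triton launch configuration benchmarked for that M on the named GPU. A minimal sketch of loading and querying such a file; the helper names and the nearest-key fallback rule are illustrative assumptions, not necessarily sglang's exact lookup logic:

```python
import json
from typing import Any, Dict


def load_moe_configs(path: str) -> Dict[int, Dict[str, Any]]:
    # JSON keys are strings ("1", "2", ..., "4096"); convert them to ints.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: Dict[int, Dict[str, Any]], num_tokens: int) -> Dict[str, Any]:
    # Fall back to the benchmarked batch size closest to the actual M.
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]


# Hypothetical usage: for num_tokens=100 this would select the "96" entry
# (BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, ...) from the first file.
# cfg = pick_config(load_moe_configs("E=160,N=320,device_name=NVIDIA_H20-3e.json"), 100)
```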
@@ -6,13 +6,13 @@ import functools
  import json
  import logging
  import os
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple

  import torch
  import triton
  import triton.language as tl

- from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.moe.topk import TopKOutput
  from sglang.srt.layers.quantization.fp8_kernel import (
      per_token_group_quant_fp8,
      scaled_fp8_quant,
@@ -39,14 +39,21 @@ _is_hip = is_hip()
  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

  if _is_cuda:
      from sgl_kernel import gelu_and_mul, silu_and_mul
  elif _is_cpu and _is_cpu_amx_available:
      pass
- else:
-     from vllm import _custom_ops as vllm_ops
-     from vllm._custom_ops import scaled_fp8_quant
+ elif _is_hip:
+     from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+     if _use_aiter:
+         try:
+             from aiter import moe_sum
+         except ImportError:
+             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+

  if _is_cuda or _is_hip:
      from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -54,9 +61,6 @@ if _is_cuda or _is_hip:

  logger = logging.getLogger(__name__)
  padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
- enable_moe_align_block_size_triton = bool(
-     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
- )


  @triton.jit
@@ -524,190 +528,6 @@ def fused_moe_kernel(
      tl.store(c_ptrs, accumulator, mask=c_mask)


- @triton.jit
- def moe_align_block_size_stage1(
-     topk_ids_ptr,
-     tokens_cnts_ptr,
-     num_experts: tl.constexpr,
-     numel: tl.constexpr,
-     tokens_per_thread: tl.constexpr,
- ):
-     pid = tl.program_id(0)
-
-     start_idx = pid * tokens_per_thread
-
-     off_c = (pid + 1) * num_experts
-
-     for i in range(tokens_per_thread):
-         if start_idx + i < numel:
-             idx = tl.load(topk_ids_ptr + start_idx + i)
-             token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-             tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
- @triton.jit
- def moe_align_block_size_stage2(
-     tokens_cnts_ptr,
-     num_experts: tl.constexpr,
- ):
-     pid = tl.program_id(0)
-
-     last_cnt = 0
-     for i in range(1, num_experts + 1):
-         token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-         last_cnt = last_cnt + token_cnt
-         tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
- @triton.jit
- def moe_align_block_size_stage3(
-     total_tokens_post_pad_ptr,
-     tokens_cnts_ptr,
-     cumsum_ptr,
-     num_experts: tl.constexpr,
-     block_size: tl.constexpr,
- ):
-     last_cumsum = 0
-     off_cnt = num_experts * num_experts
-     for i in range(1, num_experts + 1):
-         token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-         last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-         tl.store(cumsum_ptr + i, last_cumsum)
-     tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
- @triton.jit
- def moe_align_block_size_stage4(
-     topk_ids_ptr,
-     sorted_token_ids_ptr,
-     expert_ids_ptr,
-     tokens_cnts_ptr,
-     cumsum_ptr,
-     num_experts: tl.constexpr,
-     block_size: tl.constexpr,
-     numel: tl.constexpr,
-     tokens_per_thread: tl.constexpr,
- ):
-     pid = tl.program_id(0)
-     start_idx = tl.load(cumsum_ptr + pid)
-     end_idx = tl.load(cumsum_ptr + pid + 1)
-
-     for i in range(start_idx, end_idx, block_size):
-         tl.store(expert_ids_ptr + i // block_size, pid)
-
-     start_idx = pid * tokens_per_thread
-     off_t = pid * num_experts
-
-     for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
-         expert_id = tl.load(topk_ids_ptr + i)
-         token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-         rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-         tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-         tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
- def moe_align_block_size_triton(
-     topk_ids: torch.Tensor,
-     num_experts: int,
-     block_size: int,
-     sorted_token_ids: torch.Tensor,
-     expert_ids: torch.Tensor,
-     num_tokens_post_pad: torch.Tensor,
- ) -> None:
-     numel = topk_ids.numel()
-     grid = (num_experts,)
-     tokens_cnts = torch.zeros(
-         (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-     )
-     cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
-     tokens_per_thread = ceil_div(numel, num_experts)
-
-     moe_align_block_size_stage1[grid](
-         topk_ids,
-         tokens_cnts,
-         num_experts,
-         numel,
-         tokens_per_thread,
-     )
-     moe_align_block_size_stage2[grid](
-         tokens_cnts,
-         num_experts,
-     )
-     moe_align_block_size_stage3[(1,)](
-         num_tokens_post_pad,
-         tokens_cnts,
-         cumsum,
-         num_experts,
-         block_size,
-     )
-     moe_align_block_size_stage4[grid](
-         topk_ids,
-         sorted_token_ids,
-         expert_ids,
-         tokens_cnts,
-         cumsum,
-         num_experts,
-         block_size,
-         numel,
-         tokens_per_thread,
-     )
-
-
- @triton.jit
- def init_sorted_ids_and_cumsum_buffer_kernel(
-     sorted_ids_ptr,
-     cumsum_buffer_ptr,
-     max_num_tokens_padded,
-     topk_ids_numel,
-     num_experts: tl.constexpr,
-     BLOCK_SIZE: tl.constexpr,
-     ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
- ):
-     pid = tl.program_id(0)
-     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
-     sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
-     if pid < sorted_ids_blocks:
-         mask = offsets < max_num_tokens_padded
-         tl.store(
-             sorted_ids_ptr + offsets,
-             tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
-             mask=mask,
-         )
-     elif pid == sorted_ids_blocks:
-         offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
-         mask_e = offset_e < num_experts + 1
-         tl.store(
-             cumsum_buffer_ptr + offset_e,
-             tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
-             mask=mask_e,
-         )
-
-
- def init_sorted_ids_and_cumsum_buffer(
-     max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
- ):
-     sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
-     cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-
-     BLOCK_SIZE = 1024
-     sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-     grid = (sorted_ids_blocks + 1,)
-
-     init_sorted_ids_and_cumsum_buffer_kernel[grid](
-         sorted_ids,
-         cumsum_buffer,
-         max_num_tokens_padded,
-         topk_ids_numel,
-         num_experts,
-         BLOCK_SIZE,
-         next_power_of_2(num_experts + 1),
-     )
-
-     return sorted_ids, cumsum_buffer
-
-
  def moe_align_block_size(
      topk_ids: torch.Tensor, block_size: int, num_experts: int
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -752,42 +572,37 @@ def moe_align_block_size(
      sorted_ids = torch.empty(
          (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
      )
-     sorted_ids.fill_(topk_ids.numel())
-
      max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
      expert_ids = torch.empty(
          (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
      )
      num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-     if enable_moe_align_block_size_triton:
-         moe_align_block_size_triton(
-             topk_ids,
-             num_experts,
-             block_size,
-             sorted_ids,
-             expert_ids,
-             num_tokens_post_pad,
-         )
-     else:
-         cumsum_buffer = torch.empty(
-             (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-         )
-         token_cnts_buffer = torch.empty(
-             (num_experts + 1) * num_experts,
-             dtype=torch.int32,
-             device=topk_ids.device,
-         )

-         sgl_moe_align_block_size(
-             topk_ids,
-             num_experts,
-             block_size,
-             sorted_ids,
-             expert_ids,
-             num_tokens_post_pad,
-             token_cnts_buffer,
-             cumsum_buffer,
-         )
+     cumsum_buffer = torch.empty(
+         (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+     )
+     token_cnts_buffer = torch.empty(
+         (num_experts + 1) * num_experts,
+         dtype=torch.int32,
+         device=topk_ids.device,
+     )
+
+     # Threshold based on benchmark results
+     fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+     if not fuse_sorted_ids_padding:
+         sorted_ids.fill_(topk_ids.numel())
+
+     sgl_moe_align_block_size(
+         topk_ids,
+         num_experts,
+         block_size,
+         sorted_ids,
+         expert_ids,
+         num_tokens_post_pad,
+         token_cnts_buffer,
+         cumsum_buffer,
+         fuse_sorted_ids_padding,
+     )
      return sorted_ids, expert_ids, num_tokens_post_pad


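For orientation: `moe_align_block_size` sorts the flattened (token, expert) assignments by expert and pads each expert's segment to a multiple of `block_size`, so the fused kernel can process fixed-size blocks that each touch a single expert's weights. The hunk above only changes who writes the `sorted_ids` padding sentinel: for buffers up to 4096 entries the fill is delegated to `sgl_moe_align_block_size` via the new `fuse_sorted_ids_padding` flag. A rough CPU reference of the three outputs, as an illustrative sketch only (the real work is done by the `sgl_kernel` CUDA op):

```python
import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # topk_ids: (num_tokens, top_k) expert indices; the sentinel value
    # topk_ids.numel() marks padding slots, matching sorted_ids.fill_ above.
    numel = topk_ids.numel()
    flat = topk_ids.flatten()
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0]
        pad = (-idx.numel()) % block_size  # pad segment to a block multiple
        sorted_ids.append(torch.cat([idx, idx.new_full((pad,), numel)]))
        expert_ids += [e] * ((idx.numel() + pad) // block_size)  # one id per block
    sorted_ids = torch.cat(sorted_ids).to(torch.int32)
    return (
        sorted_ids,
        torch.tensor(expert_ids, dtype=torch.int32),
        torch.tensor([sorted_ids.numel()], dtype=torch.int32),  # num_tokens_post_pad
    )
```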
@@ -1328,8 +1143,7 @@ def fused_experts(
      hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
-     topk_weights: torch.Tensor,
-     topk_ids: torch.Tensor,
+     topk_output: TopKOutput,
      inplace: bool = False,
      activation: str = "silu",
      apply_router_weight_on_input: bool = False,
@@ -1348,7 +1162,7 @@
      no_combine: bool = False,
      routed_scaling_factor: Optional[float] = None,
  ):
-
+     topk_weights, topk_ids, _ = topk_output
      if inplace:
          assert not no_combine, "no combine + inplace makes no sense"
          torch.ops.sglang.inplace_fused_experts(
@@ -1517,11 +1331,7 @@ def fused_experts_impl(
      routed_scaling_factor: Optional[float] = None,
  ):
      padded_size = padding_size
-     if (
-         not (use_fp8_w8a8 or use_int8_w8a8)
-         or block_shape is not None
-         or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
-     ):
+     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
          padded_size = 0

      # Check constraints.
@@ -1719,6 +1529,17 @@
                  out_hidden_states[begin_chunk_idx:end_chunk_idx],
                  routed_scaling_factor,
              )
+         elif _is_hip:
+             if _use_aiter:
+                 moe_sum(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                 )
+             else:
+                 vllm_ops.moe_sum(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                 )
          else:
              vllm_ops.moe_sum(
                  intermediate_cache3.view(*intermediate_cache3.shape),
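The new `elif _is_hip:` branch above routes the final reduction to aiter's `moe_sum` when `SGLANG_USE_AITER` is set, and to `vllm_ops.moe_sum` otherwise. Either way, the operation collapses the per-expert partial results into one hidden state per token; functionally it is equivalent to the sketch below (the shapes are my reading of the surrounding code, not verified against the kernels):

```python
import torch


def moe_sum_ref(intermediate_cache3: torch.Tensor, out: torch.Tensor) -> None:
    # intermediate_cache3: (num_tokens, top_k, hidden) already-weighted
    # expert outputs; out: (num_tokens, hidden) destination buffer.
    torch.sum(intermediate_cache3, dim=1, out=out)
```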
@@ -1732,17 +1553,10 @@ def fused_moe(
      hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
-     gating_output: torch.Tensor,
-     topk: int,
-     renormalize: bool,
+     topk_output: TopKOutput,
      inplace: bool = False,
      activation: str = "silu",
      apply_router_weight_on_input: bool = False,
-     use_grouped_topk: bool = False,
-     num_expert_group: Optional[int] = None,
-     num_fused_shared_experts: int = 0,
-     topk_group: Optional[int] = None,
-     custom_routing_function: Optional[Callable] = None,
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
@@ -1766,16 +1580,9 @@
      - hidden_states (torch.Tensor): The input tensor to the MoE layer.
      - w1 (torch.Tensor): The first set of expert weights.
      - w2 (torch.Tensor): The second set of expert weights.
-     - gating_output (torch.Tensor): The output of the gating operation
-         (before softmax).
-     - topk (int): The number of top-k experts to select.
-     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+     - topk_output (TopKOutput): The top-k output of the experts.
      - inplace (bool): If True, perform the operation in-place.
          Defaults to False.
-     - num_expert_group: Optional[int]: additional parameter for grouped_topk
-     - topk_group: Optional[int]: additional parameter for grouped_topk
-     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
-         note: Deepseek V2/V3/R1 series models use grouped_topk
      - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
          products for w1 and w2. Defaults to False.
      - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1799,28 +1606,12 @@
      Returns:
      - torch.Tensor: The output tensor after applying the MoE layer.
      """
-     # Check constraints.
-     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-
-     topk_weights, topk_ids = select_experts(
-         hidden_states=hidden_states,
-         router_logits=gating_output,
-         use_grouped_topk=use_grouped_topk,
-         top_k=topk,
-         renormalize=renormalize,
-         topk_group=topk_group,
-         num_expert_group=num_expert_group,
-         num_fused_shared_experts=num_fused_shared_experts,
-         custom_routing_function=custom_routing_function,
-         routed_scaling_factor=routed_scaling_factor,
-     )

      return fused_experts(
          hidden_states,
          w1,
          w2,
-         topk_weights,
-         topk_ids,
+         topk_output,
          inplace=inplace,
          activation=activation,
          apply_router_weight_on_input=apply_router_weight_on_input,
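
Taken together, the fused_moe.py hunks move expert selection out of `fused_moe`: callers now run top-k routing once and pass the resulting `TopKOutput` (which unpacks as `topk_weights, topk_ids, _`, per the `fused_experts` change above) instead of `gating_output`/`topk`/`renormalize`. A hedged sketch of the caller-side migration; how `TopKOutput` is constructed is not shown in this diff, so the routing step is left abstract:

```python
# Before (0.4.9.post2): fused_moe routed internally from the raw router logits.
# out = fused_moe(hidden_states, w1, w2, gating_output=router_logits,
#                 topk=2, renormalize=True)

# After (0.4.9.post4): routing happens once up front; fused_moe receives it.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe


def run_moe(hidden_states, w1, w2, topk_output):
    # topk_output is a TopKOutput from sglang.srt.layers.moe.topk;
    # fused_experts unpacks it as (topk_weights, topk_ids, _).
    return fused_moe(hidden_states, w1, w2, topk_output, inplace=False)
```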