sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (78)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,178 @@
+ {
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 4,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "16": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 4,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "8192": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "16384": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "32768": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "65536": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "131072": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ }
+ }
@@ -0,0 +1,200 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "24": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 2,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 2,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "256": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 32,
+ "kpack": 2
+ },
+ "512": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ }
+ }
@@ -0,0 +1,175 @@
+ {
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 4,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "16": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "256": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 4,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ },
+ "8192": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "16384": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "32768": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 0,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "65536": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 0,
+ "waves_per_eu": 1,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 1
+ },
+ "131072": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 0,
+ "waves_per_eu": 2,
+ "matrix_instr_nonkdim": 16,
+ "kpack": 2
+ }
+ }
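
Each of these new config files maps a token-batch size (the JSON keys) to a tuned Triton tile configuration for the fused-MoE kernel on that device and dtype: block sizes, group size, warps, stages, and the AMD-specific waves_per_eu / matrix_instr_nonkdim / kpack knobs. A minimal sketch of how such a table can be consulted, assuming selection by the numerically closest key; the actual lookup lives in fused_moe.py and may differ:

    import json

    def pick_moe_config(path: str, num_tokens: int) -> dict:
        # Load the per-batch-size table and return the entry whose key is
        # numerically closest to the current number of tokens.
        with open(path) as f:
            table = {int(m): cfg for m, cfg in json.load(f).items()}
        best_m = min(table, key=lambda m: abs(m - num_tokens))
        return table[best_m]

    # Hypothetical usage: pick_moe_config("<one of the config files above>", 1000)
    # would return the entry tuned for the nearest listed batch size.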
@@ -15,18 +15,10 @@ from vllm import _custom_ops as ops
 
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
- from sglang.srt.utils import (
-     direct_register_custom_op,
-     get_device_name,
-     is_cuda_available,
-     is_hip,
- )
+ from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
- is_cuda = is_cuda_available()
  is_hip_flag = is_hip()
- if is_cuda:
-     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-
+ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
  logger = logging.getLogger(__name__)
  padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -415,7 +407,7 @@ def moe_align_block_size(
      )
      num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
      if num_experts >= 224:
-         if enable_moe_align_block_size_triton or is_hip_flag:
+         if enable_moe_align_block_size_triton:
              moe_align_block_size_triton(
                  topk_ids,
                  num_experts,
@@ -425,12 +417,12 @@ def moe_align_block_size(
                  num_tokens_post_pad,
              )
          else:
-             token_cnts_buffer = torch.empty(
+             token_cnts_buffer = torch.zeros(
                  (num_experts + 1) * num_experts,
                  dtype=torch.int32,
                  device=topk_ids.device,
              )
-             cumsum_buffer = torch.empty(
+             cumsum_buffer = torch.zeros(
                  num_experts + 1, dtype=torch.int32, device=topk_ids.device
              )
 
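The switch from torch.empty to torch.zeros matters here, presumably because these buffers are used as running counters and prefix sums by the alignment kernel: torch.empty returns uninitialized memory, while torch.zeros guarantees a zeroed start. A tiny standalone illustration of the difference (not sglang code):

    import torch

    workspace = torch.empty(8, dtype=torch.int32)  # contents are arbitrary leftover bytes
    counters = torch.zeros(8, dtype=torch.int32)   # guaranteed zero-initialized

    counters[3] += 1
    assert counters.sum().item() == 1  # holds only because the buffer started at zero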
@@ -5,14 +5,13 @@ from enum import Enum
  from typing import Callable, List, Optional, Tuple
 
  import torch
- from vllm.model_executor.custom_op import CustomOp
 
+ from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from sglang.srt.layers.custom_op_util import register_custom_op
  from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.base_config import (
@@ -67,7 +66,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
          raise NotImplementedError
 
 
- @register_custom_op("sglang_unquantized_fused_moe")
  class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
      """MoE method without quantization."""
 
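These two hunks are part of dropping the vLLM dependency for custom ops: CustomOp now comes from the new sglang/srt/custom_op.py (added in this release, see the file list above), and the register_custom_op decorator from the removed sglang/srt/layers/custom_op_util.py is gone. A rough sketch of the dispatch pattern such a CustomOp base class typically provides; the method names and selection logic below are illustrative, not copied from the new module:

    import torch
    from torch import nn

    class CustomOp(nn.Module):
        """Illustrative base class: route forward() to a per-backend implementation."""

        def __init__(self):
            super().__init__()
            # Choose the implementation once, based on the device that is available.
            self._forward_method = (
                self.forward_cuda if torch.cuda.is_available() else self.forward_native
            )

        def forward(self, *args, **kwargs):
            return self._forward_method(*args, **kwargs)

        def forward_native(self, *args, **kwargs):
            # Pure-PyTorch fallback; subclasses must provide it.
            raise NotImplementedError

        def forward_cuda(self, *args, **kwargs):
            # Default to the native path unless a subclass overrides it with a fused kernel.
            return self.forward_native(*args, **kwargs)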
@@ -17,6 +17,8 @@ from typing import Callable, Optional
  import torch
  import torch.nn.functional as F
 
+ from sglang.srt.utils import get_compiler_backend
+
 
  def fused_topk_native(
      hidden_states: torch.Tensor,
@@ -74,6 +76,7 @@ def fused_topk(
 
 
  # This is used by the Deepseek-V2 model
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
  def grouped_topk(
      hidden_states: torch.Tensor,
      gating_output: torch.Tensor,
@@ -108,6 +111,7 @@ def grouped_topk(
      return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
  def biased_grouped_topk(
      hidden_states: torch.Tensor,
      gating_output: torch.Tensor,
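
In sglang/srt/layers/moe/topk.py, grouped_topk and biased_grouped_topk are now wrapped in torch.compile, with the backend chosen by the new get_compiler_backend() helper added to sglang/srt/utils.py. A hedged sketch of the pattern; the stand-in backend selection below is illustrative, not the actual utility:

    import torch

    def get_compiler_backend() -> str:
        # Illustrative stand-in: the real selection logic lives in sglang/srt/utils.py.
        return "inductor" if torch.cuda.is_available() else "eager"

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    def topk_like(scores: torch.Tensor, k: int):
        # dynamic=True avoids recompiling the graph for every new batch size.
        return torch.topk(scores, k=k, dim=-1)

    weights, ids = topk_like(torch.randn(4, 16), 2)
    print(ids.shape)  # torch.Size([4, 2])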