sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +330 -200
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +12 -5
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +25 -13
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +1 -0
- sglang/srt/layers/radix_attention.py +13 -1
- sglang/srt/layers/rotary_embedding.py +12 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +48 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +1 -0
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
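These tuning files map a token-batch size (the JSON keys "1" through "4096") to a Triton launch configuration for the fused MoE kernel on a specific GPU and expert/weight shape. As a minimal sketch, assuming the runtime picks the entry whose key is nearest to the actual number of tokens (the helper names below are illustrative, not sglang APIs), such a table might be consulted like this:

import json

def load_moe_config(path: str) -> dict[int, dict]:
    # Keys are token-batch sizes stored as strings ("1", "2", ..., "4096"); convert to int.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    # Assumed selection rule: use the tuned entry whose key is closest to the actual batch size.
    return configs[min(configs, key=lambda k: abs(k - num_tokens))]

cfg = load_moe_config("E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json")
print(pick_config(cfg, num_tokens=200))
# -> the "256" entry: BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, GROUP_SIZE_M=64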
sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -342,6 +342,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +417,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)

-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +426,7 @@ def fused_moe_kernel(
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-
-        # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
@@ -447,6 +434,10 @@ def fused_moe_kernel(
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -711,12 +702,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.
+        token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.
+        cumsum_buffer = torch.empty(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )

@@ -753,6 +744,7 @@ def invoke_fused_moe_kernel(
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
+    per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
@@ -765,6 +757,8 @@ def invoke_fused_moe_kernel(
         from sglang.srt.layers.quantization.fp8_kernel import (
            sglang_per_token_group_quant_fp8,
        )
+    else:
+        from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -775,10 +769,15 @@ def invoke_fused_moe_kernel(
         if block_shape is None:
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
+            # activations apply per-token quantization when weights apply per-channel quantization by default
             if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(
+                A, A_scale = sgl_scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
             else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(
+                A, A_scale = vllm_ops.scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -794,6 +793,9 @@ def invoke_fused_moe_kernel(
         assert B_scale is not None
         if block_shape is None:
             # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
             A, A_scale = per_token_quant_int8(A)
         else:
             # activation block-wise int8 quantization
@@ -902,6 +904,7 @@ def invoke_fused_moe_kernel(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        per_channel_quant=per_channel_quant,
         even_Ks=even_Ks,
         **config,
     )
@@ -953,7 +956,7 @@ def get_moe_configs(
         logger.warning(
             (
                 "Using default MoE config. Performance might be sub-optimal! "
-                "Config file not found at %s"
+                "Config file not found at %s, you can tune the config with https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py."
             ),
             config_file_path,
         )
@@ -1084,6 +1087,7 @@ def inplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1105,6 +1109,7 @@ def inplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1127,6 +1132,7 @@ def inplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1158,6 +1164,7 @@ def outplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1180,6 +1187,7 @@ def outplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1203,6 +1211,7 @@ def outplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1236,6 +1245,7 @@ def fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1259,6 +1269,7 @@ def fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1281,6 +1292,7 @@ def fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1305,6 +1317,7 @@ def fused_experts_impl(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1441,6 +1454,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
         if activation == "silu":
@@ -1484,6 +1498,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )

@@ -1530,6 +1545,7 @@ def fused_moe(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1606,6 +1622,7 @@ def fused_moe(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
sglang/srt/layers/moe/router.py
CHANGED
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()


 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }

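The router change caps the warp count differently on ROCm and CUDA. The standalone sketch below reproduces the clamped heuristic from the diff without importing Triton (cdiv and next_power_of_2 are reimplemented here purely for illustration).

import math

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 2 ** math.ceil(math.log2(n))

def router_num_warps(hidden_dim: int, is_hip: bool) -> int:
    # Roughly one warp per 256 hidden elements, clamped to [4, 16] on ROCm and [4, 32] on CUDA.
    min_num_warps = 16 if is_hip else 32
    return max(min(next_power_of_2(cdiv(hidden_dim, 256)), min_num_warps), 4)

for dim in (2048, 4096, 8192, 16384):
    print(dim, router_num_warps(dim, is_hip=False), router_num_warps(dim, is_hip=True))
# 2048 -> 8/8, 4096 -> 16/16, 8192 -> 32/16, 16384 -> 32/16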
sglang/srt/layers/moe/topk.py
CHANGED
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================

+import math
 import os
 from typing import Callable, Optional

@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 _is_cuda = is_cuda()
 _is_hip = is_hip()

+if _is_cuda:
+    from sgl_kernel import moe_fused_gate

 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if (
+        _is_cuda
+        and n_share_experts_fusion == 0
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+        )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )


 def select_experts(
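The fused sgl_kernel gate is only taken when running on CUDA, shared-expert fusion is off, and the number of experts (the length of correction_bias) is a power of two; otherwise the compiled PyTorch fallback runs. The power-of-two check added by the diff is reproduced below, together with an equivalent bit-twiddling form for positive integers (added here for illustration).

import math

def is_power_of_two(n: int) -> bool:
    # Same check the diff adds in topk.py.
    return n > 0 and math.log2(n).is_integer()

def is_power_of_two_bits(n: int) -> bool:
    # Equivalent branch-free form for positive integers.
    return n > 0 and (n & (n - 1)) == 0

for n in (1, 64, 256, 272, 288):
    assert is_power_of_two(n) == is_power_of_two_bits(n)
    print(n, is_power_of_two(n))
# 1, 64 and 256 are powers of two; 272 and 288 are not.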
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -59,20 +59,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
+    "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
@@ -176,6 +176,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
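This hunk registers the new ModelOptFp4Config under the "modelopt_fp4" key and defers the vocab_parallel_embedding import into get_linear_quant_method to break a circular import. A minimal sketch of how a name-keyed registry of this shape is typically resolved follows; the stub classes and the lookup helper are illustrative, not sglang APIs.

from typing import Dict, Type

class QuantizationConfig: ...
class Fp8Config(QuantizationConfig): ...
class ModelOptFp8Config(QuantizationConfig): ...
class ModelOptFp4Config(QuantizationConfig): ...   # newly registered in this release

BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "modelopt": ModelOptFp8Config,
    "modelopt_fp4": ModelOptFp4Config,
}

def resolve_quant_config(name: str) -> Type[QuantizationConfig]:
    # Fail loudly on unknown method names instead of silently falling back.
    try:
        return BASE_QUANTIZATION_METHODS[name]
    except KeyError as e:
        raise ValueError(f"Unknown quantization method: {name!r}") from e

print(resolve_quant_config("modelopt_fp4"))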
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
CHANGED
@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
         )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})

         return cls(
             target_scheme_map=target_scheme_map,
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )

     @classmethod
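A packed_modules_mapping typically records how per-shard checkpoint modules were fused into packed modules, so ignore/target rules written against the original names can still be matched. The mapping values below are typical examples of what such an entry looks like, not data taken from this diff, and the unpack helper is illustrative.

quant_config = {
    "packed_modules_mapping": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
}

# Mirrors the new config.get(...) in from_config: default to an empty mapping.
packed_modules_mapping = quant_config.get("packed_modules_mapping", {})

def unpack(module_name: str) -> list[str]:
    # Expand a packed module name back to its constituent shard names.
    return packed_modules_mapping.get(module_name, [module_name])

print(unpack("qkv_proj"))   # ['q_proj', 'k_proj', 'v_proj']
print(unpack("o_proj"))     # ['o_proj']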
|