sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +72 -33
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
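The two tuning tables above, from the new `configs/*.json` files in the summary list, map a token count (the JSON key) to Triton kernel meta-parameters; AMD-tuned entries carry the ROCm-specific `waves_per_eu` knob, while the NVIDIA B200 entries do not. A minimal sketch of how such a table is typically consumed — the `load_tuned_configs`/`pick_config` helpers are hypothetical, not the package's actual loader — rounds the runtime batch dimension to the nearest tuned key:

```python
import json
from typing import Any, Dict


def load_tuned_configs(path: str) -> Dict[int, Dict[str, Any]]:
    # Keys are stored as strings ("1", "2", ..., "4096"); use ints for lookup.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}


def pick_config(configs: Dict[int, Dict[str, Any]], m: int) -> Dict[str, Any]:
    # Choose the tuned batch-size bucket closest to the actual M.
    return configs[min(configs, key=lambda k: abs(k - m))]
```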
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -22,7 +22,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name, is_hip
+from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 
 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@@ -220,6 +220,132 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@triton.jit
+def _w8a8_block_fp8_matmul_unrolledx4(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    # manually unroll to 4 iterations
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4):
+        # 1st iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        # 2nd iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k_start + BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        # 3rd iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k_start + BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        # 4th iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k_start + BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
 @functools.lru_cache
 def get_w8a8_block_fp8_configs(
     N: int, K: int, block_n: int, block_k: int
@@ -324,7 +450,19 @@ def w8a8_block_fp8_matmul(
         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
 
-    _w8a8_block_fp8_matmul[grid](
+    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+    # Empirical testing shows the sweet spot lies when it's less than the # of
+    # compute units available on the device.
+    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
+        N, config["BLOCK_SIZE_N"]
+    )
+    kernel = (
+        _w8a8_block_fp8_matmul_unrolledx4
+        if (is_hip_ == True and num_workgroups <= get_device_core_count())
+        else _w8a8_block_fp8_matmul
+    )
+
+    kernel[grid](
         A,
         B,
         C,
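The dispatch above gates on `get_device_core_count`, part of the 7 lines this release adds to `sglang/srt/utils.py` (see the file list); that diff body is not shown here. A plausible sketch of such a helper — an assumption, not the shipped code — reads the device's multiprocessor/compute-unit count through torch:

```python
import torch


def get_device_core_count(device_id: int = 0) -> int:
    # Number of SMs (NVIDIA) or CUs (AMD); torch exposes both via
    # multi_processor_count. Returning 0 when no GPU is present means
    # the unrolled-kernel heuristic simply never fires.
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(device_id).multi_processor_count
    return 0
```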
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -7,9 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
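Related to the import change above: this release replaces vllm's `CustomOp` with a local `sglang/srt/custom_op.py` (+40 lines in the file list), whose diff body is not reproduced here. A sketch of what such a base class usually looks like — modeled on vllm's `CustomOp` and labeled as an assumption, not the exact 40 lines shipped — binds `forward` to a per-platform implementation once at construction:

```python
import torch
from torch import nn


class CustomOp(nn.Module):
    """Sketch: resolve the platform-specific forward once, at construction."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # HIP reuses the CUDA path unless a subclass overrides it.
        return self.forward_cuda(*args, **kwargs)

    def dispatch_forward(self):
        if torch.cuda.is_available():
            # torch.version.hip is non-None on ROCm builds of torch.
            if torch.version.hip is not None:
                return self.forward_hip
            return self.forward_cuda
        return self.forward_native
```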
sglang/srt/layers/sampler.py
CHANGED
@@ -85,7 +85,7 @@ class Sampler(nn.Module):
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids
+                batch_next_token_ids = min_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
@@ -97,9 +97,9 @@ class Sampler(nn.Module):
                     filter_apply_order="joint",
                 )
 
-
-
-
+                if self.use_nan_detectioin and not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
 
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
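For context, `min_p_sampling_from_probs` is the fused min-p kernel used on the flashinfer path. A plain-PyTorch sketch of the same filtering rule — keep tokens whose probability is at least `min_p` times the top token's probability, renormalize, then sample; illustrative only, not the fused kernel's exact semantics — looks like this:

```python
import torch


def min_p_sample_reference(probs: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
    # probs: (batch, vocab), rows summing to 1; min_ps: (batch,) thresholds.
    top = probs.max(dim=-1, keepdim=True).values
    keep = probs >= min_ps.unsqueeze(-1) * top            # min-p filter
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)  # renormalize
    # The argmax token always survives the filter, so each row is non-empty.
    return torch.multinomial(filtered, num_samples=1).squeeze(-1)
```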
sglang/srt/lora/backend/base_backend.py
ADDED
@@ -0,0 +1,95 @@
+from typing import Tuple, Union
+
+import torch
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+    mapping = {
+        "triton": True,
+        "flashinfer": False,
+    }
+    return mapping.get(name, False)
+
+
+def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+    mapping = {
+        "triton": True,
+        "flashinfer": False,
+    }
+    return mapping.get(name, False)
+
+
+class BaseLoraBackend:
+    """Base class for different Lora backends.
+    Each backend has its own implementation of Lora kernels.
+
+    Args:
+        name: name of backend
+        batch_info: information of current batch for use
+        fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+            and the operation of scaling and adding will be fused into kernel
+    """
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        self.name = name
+        self.batch_info = batch_info
+        self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+        self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        """Run segment Gemm of lora a modules with current backend.
+        The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+                usually input_dim is much larger than r
+        Returns:
+            result with shape (s, r)
+        """
+        pass
+
+    def run_lora_b_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        """Run segment Gemm of lora b modules with current backend.
+        The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+        Args:
+            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+            weights: a set of lora weights with shape (num_lora, output_dim, r)
+                usually output_dim is much larger than r
+        Returns:
+            result with shape (s, output_dim)
+        """
+        pass
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        """Run the lora pass for QKV Layer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
+            qkv_lora_b: lora_b module for qkv.
+                If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
+                If passed in as a tuple of two tensors containing:
+                    a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
+                    and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
+        Returns:
+            result with shape (s, output_dim_q + 2 * output_dim_kv)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoraBatchInfo):
+        self.batch_info = batch_info
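To illustrate the contract defined above, a naive backend could implement `run_lora_a_sgemm` with a plain per-segment matmul loop. The sketch below assumes `LoraBatchInfo` exposes `bs`, `seg_lens`, and `weight_indices` fields (per-request segment lengths and adapter indices); both the class and that assumption are illustrative, not one of the shipped triton/flashinfer backends:

```python
import torch

from sglang.srt.lora.backend.base_backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo


class NaiveTorchLoraBackend(BaseLoraBackend):
    """Hypothetical loop-over-segments reference backend, for illustration."""

    def __init__(self, batch_info: LoraBatchInfo = None):
        super().__init__("naive_torch", batch_info)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        # x: (s, input_dim); weights: (num_lora, r, input_dim) -> out: (s, r)
        out = x.new_empty(x.shape[0], weights.shape[1])
        start = 0
        for i in range(self.batch_info.bs):
            end = start + int(self.batch_info.seg_lens[i])
            # Adapter weights assigned to this request's segment.
            w = weights[int(self.batch_info.weight_indices[i])]  # (r, input_dim)
            out[start:end] = x[start:end] @ w.transpose(0, 1)
            start = end
        return out
```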