sglang 0.4.2.post3__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -3
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +64 -21
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +41 -24
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
One of the new NVIDIA_H20 fp8_w8a8 block_shape=[128, 128] tuning configs (see the file list above):

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
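The entries above map a token-batch size (the JSON keys) to Triton launch parameters for the block-quantized fp8 MoE GEMM on H20. Below is a minimal sketch of how such a config is typically consumed, assuming the vLLM-style nearest-batch-size lookup used by fused MoE kernels; the helper names here are illustrative and not sglang's API:

```python
import json

def load_block_fp8_moe_config(path: str) -> dict:
    # Keys in the JSON are strings ("1", "2", ..., "4096"); convert to ints.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict, num_tokens: int) -> dict:
    # Heuristic used by vLLM-style fused MoE kernels: take the tuned entry
    # whose batch-size key is numerically closest to the actual token count.
    best_key = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[best_key]

# Example: a batch of 100 tokens would map to the "96" entry above.
# cfg = pick_config(load_block_fp8_moe_config("N=1536,K=1536,...json"), 100)
```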
Another of the new NVIDIA_H20 fp8_w8a8 block_shape=[128, 128] tuning configs:

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
A third of the new NVIDIA_H20 fp8_w8a8 block_shape=[128, 128] tuning configs:

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
sglang/srt/lora/backend/__init__.py

@@ -1,8 +1,28 @@
-from .base_backend import
-from .flashinfer_backend import
-from .triton_backend import
+from .base_backend import BaseLoRABackend
+from .flashinfer_backend import FlashInferLoRABackend
+from .triton_backend import TritonLoRABackend
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    backend_mapping = {
+        "triton": TritonLoRABackend,
+        "flashinfer": FlashInferLoRABackend,
+    }
+
+    if name in backend_mapping:
+        return backend_mapping[name]
+
+    raise Exception(
+        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
+    )
+
 
 __all__ = [
-    "
-    "
+    "BaseLoRABackend",
+    "FlashInferLoRABackend",
+    "TritonLoRABackend",
+    "get_backend_from_name",
 ]
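With the registry above, callers can resolve a backend class from its name string. A small usage sketch based only on the signatures shown in this diff (building a real LoRABatchInfo is elided):

```python
from sglang.srt.lora.backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")   # -> TritonLoRABackend
backend = backend_cls("triton")                 # batch_info defaults to None
# backend.set_batch_info(batch_info)            # attached per forward batch before running kernels

# An unrecognized name raises an Exception listing the supported backends:
# get_backend_from_name("vllm")  ->  "No supported lora backend called vllm. It should be one of ['triton', 'flashinfer']"
```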
sglang/srt/lora/backend/base_backend.py

@@ -2,7 +2,7 @@ from typing import Tuple, Union
 
 import torch
 
-from sglang.srt.lora.
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 def get_fuse_output_scaling_add_from_name(name: str) -> bool:
@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
     return mapping.get(name, False)
 
 
-def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
     return mapping.get(name, False)
 
 
-class BaseLoraBackend:
+class BaseLoRABackend:
     """Base class for different Lora backends.
     Each backend has its own implementation of Lora kernels.
 
@@ -32,11 +32,11 @@ class BaseLoraBackend:
     and the operation of scaling and adding will be fused into kernel
     """
 
-    def __init__(self, name: str, batch_info:
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
         self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
-        self.
+        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -46,10 +46,11 @@ class BaseLoraBackend:
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-            weights: a set of lora weights with shape (num_lora, r, input_dim),
+            weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+                here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
                 usually input_dim is much larger than r
         Returns:
-            result with shape (s, r)
+            result with shape (s, c * r)
         """
         pass
 
@@ -83,7 +84,7 @@ class BaseLoraBackend:
             qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
             qkv_lora_b: lora_b module for qkv.
                 If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
-                If passed in as a tuple of two tensors
+                If passed in as a tuple of two tensors, it should contain:
                 a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
                 and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
         Returns:
@@ -91,5 +92,26 @@ class BaseLoraBackend:
         """
         pass
 
-    def
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
+            gate_up_lora_b: lora_b module for qkv.
+                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
+                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
+        Returns:
+            result with shape (s, 2 * output_dim)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
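The updated run_lora_a_sgemm docstring generalizes the A-side GEMM to stacked modules, where c projections share one (num_lora, c * r, input_dim) weight. A quick plain-PyTorch shape check of that convention (the sizes here are made up for illustration):

```python
import torch

s, input_dim, r = 8, 4096, 16
c = 2          # stacked gate_proj + up_proj; c would be 3 for qkv_proj
num_lora = 4

x = torch.randn(s, input_dim)
lora_a = torch.randn(num_lora, c * r, input_dim)   # stacked lora_a for all adapters

# For one adapter, a single matmul produces both halves of the low-rank activation:
a_out = x @ lora_a[0].T
assert a_out.shape == (s, c * r)
# Columns [:r] go to gate_proj's lora_b, columns [r:] to up_proj's lora_b.
```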
sglang/srt/lora/backend/flashinfer_backend.py

@@ -2,17 +2,17 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import
-from sglang.srt.lora.
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper
 
 
-class FlashInferLoraBackend(BaseLoraBackend):
+class FlashInferLoRABackend(BaseLoRABackend):
 
-    def __init__(self, name: str, batch_info:
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)
 
         # Set up SGemm Wrapper from flashinfer
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
         **kwargs,
     ) -> torch.Tensor:
 
+        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
+
         # Shape of lora_a_output: (s, 3 * r)
         lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
 
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
         )
 
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
+        lora_rank = gate_up_lora_b[0].shape[-1]
+        output_dim = gate_up_lora_b[0].shape[-2]
+
+        # Shape of lora_a_output: (s, 2 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
+
+        lora_output = torch.empty(
+            (x.shape[0], 2 * output_dim),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # Compute lora for gate and up proj respectively
+        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(),
+            weights=gate_up_lora_b[0],
+        )
+
+        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, lora_rank:].contiguous(),
+            weights=gate_up_lora_b[1],
+        )
+
+        return lora_output
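In the FlashInfer path above, gate_up_lora_b arrives as a tuple of two (num_lora, output_dim, r) tensors and the two halves of lora_a_output go through separate B-side GEMMs. The .contiguous() calls are needed because column slices of a row-major tensor are strided views; that the segment GEMM wants densely laid-out inputs is my reading, not stated in the diff. A tiny check of the layout behavior:

```python
import torch

a_out = torch.randn(8, 32)                     # stands in for lora_a_output with r = 16
gate_half = a_out[:, :16]                      # strided view over the first r columns
print(gate_half.is_contiguous())               # False
print(gate_half.contiguous().is_contiguous())  # True, after copying into a dense buffer
```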
sglang/srt/lora/backend/triton_backend.py

@@ -1,17 +1,18 @@
 import torch
 
-from sglang.srt.lora.backend import
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
+    gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
     sgemm_lora_a_fwd,
     sgemm_lora_b_fwd,
 )
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
-class TritonLoraBackend(BaseLoraBackend):
+class TritonLoRABackend(BaseLoRABackend):
 
-    def __init__(self, name: str, batch_info:
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)
 
     def run_lora_a_sgemm(
@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
             scaling,
         )
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: torch.Tensor,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+
+        # x: (s, input_dim)
+        # gate_up_lora_a: (num_lora, 2 * r, input_dim)
+        # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+        assert isinstance(gate_up_lora_b, torch.Tensor)
+        output_dim = gate_up_lora_b.shape[-2] // 2
+
+        # lora_a_output: (s, 2 * r)
+        lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
+        lora_output = gate_up_lora_b_fwd(
+            lora_a_output,
+            gate_up_lora_b,
+            self.batch_info,
+            output_dim,
+            base_output,
+            scaling,
+        )
+        return lora_output
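The Triton path keeps gate_up_lora_b stacked as a single (num_lora, 2 * output_dim, r) tensor and hands base_output and scaling to gate_up_lora_b_fwd, so the scale-and-add is fused into the kernel (per the BaseLoRABackend docstring). Below is a dense single-adapter reference of what that call is expected to compute; this is a sketch for intuition, not the kernel itself, and it ignores the per-sequence batching handled by LoRABatchInfo:

```python
import torch

def gate_up_lora_reference(x, lora_a, lora_b_stacked, base_output, scaling):
    # x: (s, input_dim); lora_a: (2*r, input_dim); lora_b_stacked: (2*output_dim, r)
    r = lora_a.shape[0] // 2
    output_dim = lora_b_stacked.shape[0] // 2
    a_out = x @ lora_a.T                                    # (s, 2*r)
    gate = a_out[:, :r] @ lora_b_stacked[:output_dim].T     # (s, output_dim)
    up = a_out[:, r:] @ lora_b_stacked[output_dim:].T       # (s, output_dim)
    delta = torch.cat([gate, up], dim=-1)                   # (s, 2*output_dim)
    # gate_up_lora_b_fwd folds this scaling-and-add into the kernel itself.
    return base_output + scaling * delta
```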