sglang 0.4.2.post4__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/global_config.py +2 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/attention/flashinfer_backend.py +235 -110
- sglang/srt/layers/attention/triton_backend.py +358 -72
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/linear.py +12 -5
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +51 -5
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
- sglang/srt/layers/quantization/fp8_kernel.py +123 -17
- sglang/srt/layers/quantization/fp8_utils.py +33 -4
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +16 -3
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/tokenizer_manager.py +6 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +12 -1
- sglang/srt/model_executor/model_runner.py +12 -2
- sglang/srt/models/deepseek_v2.py +17 -7
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/METADATA +4 -3
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/RECORD +57 -41
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
@@ -1,61 +1,61 @@
|
|
1
1
|
{
|
2
2
|
"1": {
|
3
|
-
"BLOCK_SIZE_M":
|
4
|
-
"BLOCK_SIZE_N":
|
3
|
+
"BLOCK_SIZE_M": 32,
|
4
|
+
"BLOCK_SIZE_N": 32,
|
5
5
|
"BLOCK_SIZE_K": 128,
|
6
|
-
"GROUP_SIZE_M":
|
6
|
+
"GROUP_SIZE_M": 8,
|
7
7
|
"num_warps": 4,
|
8
8
|
"num_stages": 2,
|
9
9
|
"waves_per_eu": 0
|
10
10
|
},
|
11
11
|
"2": {
|
12
|
-
"BLOCK_SIZE_M":
|
13
|
-
"BLOCK_SIZE_N":
|
14
|
-
"BLOCK_SIZE_K":
|
15
|
-
"GROUP_SIZE_M":
|
12
|
+
"BLOCK_SIZE_M": 32,
|
13
|
+
"BLOCK_SIZE_N": 32,
|
14
|
+
"BLOCK_SIZE_K": 64,
|
15
|
+
"GROUP_SIZE_M": 8,
|
16
16
|
"num_warps": 4,
|
17
17
|
"num_stages": 2,
|
18
18
|
"waves_per_eu": 0
|
19
19
|
},
|
20
20
|
"4": {
|
21
|
-
"BLOCK_SIZE_M":
|
22
|
-
"BLOCK_SIZE_N":
|
21
|
+
"BLOCK_SIZE_M": 32,
|
22
|
+
"BLOCK_SIZE_N": 32,
|
23
23
|
"BLOCK_SIZE_K": 128,
|
24
|
-
"GROUP_SIZE_M":
|
24
|
+
"GROUP_SIZE_M": 32,
|
25
25
|
"num_warps": 4,
|
26
26
|
"num_stages": 2,
|
27
27
|
"waves_per_eu": 0
|
28
28
|
},
|
29
29
|
"8": {
|
30
|
-
"BLOCK_SIZE_M":
|
31
|
-
"BLOCK_SIZE_N":
|
30
|
+
"BLOCK_SIZE_M": 32,
|
31
|
+
"BLOCK_SIZE_N": 64,
|
32
32
|
"BLOCK_SIZE_K": 128,
|
33
|
-
"GROUP_SIZE_M":
|
33
|
+
"GROUP_SIZE_M": 16,
|
34
34
|
"num_warps": 4,
|
35
35
|
"num_stages": 2,
|
36
36
|
"waves_per_eu": 0
|
37
37
|
},
|
38
38
|
"16": {
|
39
|
-
"BLOCK_SIZE_M":
|
40
|
-
"BLOCK_SIZE_N":
|
39
|
+
"BLOCK_SIZE_M": 32,
|
40
|
+
"BLOCK_SIZE_N": 32,
|
41
41
|
"BLOCK_SIZE_K": 128,
|
42
|
-
"GROUP_SIZE_M":
|
42
|
+
"GROUP_SIZE_M": 8,
|
43
43
|
"num_warps": 4,
|
44
44
|
"num_stages": 2,
|
45
45
|
"waves_per_eu": 0
|
46
46
|
},
|
47
47
|
"24": {
|
48
|
-
"BLOCK_SIZE_M":
|
49
|
-
"BLOCK_SIZE_N":
|
48
|
+
"BLOCK_SIZE_M": 32,
|
49
|
+
"BLOCK_SIZE_N": 32,
|
50
50
|
"BLOCK_SIZE_K": 128,
|
51
|
-
"GROUP_SIZE_M":
|
51
|
+
"GROUP_SIZE_M": 8,
|
52
52
|
"num_warps": 4,
|
53
53
|
"num_stages": 2,
|
54
54
|
"waves_per_eu": 0
|
55
55
|
},
|
56
56
|
"32": {
|
57
|
-
"BLOCK_SIZE_M":
|
58
|
-
"BLOCK_SIZE_N":
|
57
|
+
"BLOCK_SIZE_M": 32,
|
58
|
+
"BLOCK_SIZE_N": 32,
|
59
59
|
"BLOCK_SIZE_K": 128,
|
60
60
|
"GROUP_SIZE_M": 16,
|
61
61
|
"num_warps": 4,
|
@@ -64,52 +64,52 @@
|
|
64
64
|
},
|
65
65
|
"48": {
|
66
66
|
"BLOCK_SIZE_M": 64,
|
67
|
-
"BLOCK_SIZE_N":
|
67
|
+
"BLOCK_SIZE_N": 32,
|
68
68
|
"BLOCK_SIZE_K": 128,
|
69
|
-
"GROUP_SIZE_M":
|
69
|
+
"GROUP_SIZE_M": 1,
|
70
70
|
"num_warps": 4,
|
71
71
|
"num_stages": 2,
|
72
72
|
"waves_per_eu": 0
|
73
73
|
},
|
74
74
|
"64": {
|
75
75
|
"BLOCK_SIZE_M": 64,
|
76
|
-
"BLOCK_SIZE_N":
|
76
|
+
"BLOCK_SIZE_N": 64,
|
77
77
|
"BLOCK_SIZE_K": 128,
|
78
|
-
"GROUP_SIZE_M":
|
78
|
+
"GROUP_SIZE_M": 4,
|
79
79
|
"num_warps": 4,
|
80
80
|
"num_stages": 2,
|
81
81
|
"waves_per_eu": 0
|
82
82
|
},
|
83
83
|
"96": {
|
84
|
-
"BLOCK_SIZE_M":
|
85
|
-
"BLOCK_SIZE_N":
|
84
|
+
"BLOCK_SIZE_M": 32,
|
85
|
+
"BLOCK_SIZE_N": 128,
|
86
86
|
"BLOCK_SIZE_K": 128,
|
87
|
-
"GROUP_SIZE_M":
|
87
|
+
"GROUP_SIZE_M": 4,
|
88
88
|
"num_warps": 4,
|
89
89
|
"num_stages": 2,
|
90
90
|
"waves_per_eu": 0
|
91
91
|
},
|
92
92
|
"128": {
|
93
|
-
"BLOCK_SIZE_M":
|
93
|
+
"BLOCK_SIZE_M": 128,
|
94
94
|
"BLOCK_SIZE_N": 32,
|
95
95
|
"BLOCK_SIZE_K": 128,
|
96
|
-
"GROUP_SIZE_M":
|
96
|
+
"GROUP_SIZE_M": 16,
|
97
97
|
"num_warps": 4,
|
98
98
|
"num_stages": 2,
|
99
99
|
"waves_per_eu": 0
|
100
100
|
},
|
101
101
|
"256": {
|
102
102
|
"BLOCK_SIZE_M": 64,
|
103
|
-
"BLOCK_SIZE_N":
|
103
|
+
"BLOCK_SIZE_N": 128,
|
104
104
|
"BLOCK_SIZE_K": 128,
|
105
|
-
"GROUP_SIZE_M":
|
105
|
+
"GROUP_SIZE_M": 16,
|
106
106
|
"num_warps": 4,
|
107
107
|
"num_stages": 2,
|
108
108
|
"waves_per_eu": 0
|
109
109
|
},
|
110
110
|
"512": {
|
111
|
-
"BLOCK_SIZE_M":
|
112
|
-
"BLOCK_SIZE_N":
|
111
|
+
"BLOCK_SIZE_M": 64,
|
112
|
+
"BLOCK_SIZE_N": 128,
|
113
113
|
"BLOCK_SIZE_K": 128,
|
114
114
|
"GROUP_SIZE_M": 32,
|
115
115
|
"num_warps": 4,
|
@@ -117,28 +117,28 @@
|
|
117
117
|
"waves_per_eu": 0
|
118
118
|
},
|
119
119
|
"1024": {
|
120
|
-
"BLOCK_SIZE_M":
|
121
|
-
"BLOCK_SIZE_N":
|
120
|
+
"BLOCK_SIZE_M": 32,
|
121
|
+
"BLOCK_SIZE_N": 128,
|
122
122
|
"BLOCK_SIZE_K": 128,
|
123
|
-
"GROUP_SIZE_M":
|
123
|
+
"GROUP_SIZE_M": 8,
|
124
124
|
"num_warps": 4,
|
125
125
|
"num_stages": 2,
|
126
126
|
"waves_per_eu": 0
|
127
127
|
},
|
128
128
|
"1536": {
|
129
129
|
"BLOCK_SIZE_M": 64,
|
130
|
-
"BLOCK_SIZE_N":
|
130
|
+
"BLOCK_SIZE_N": 128,
|
131
131
|
"BLOCK_SIZE_K": 128,
|
132
|
-
"GROUP_SIZE_M":
|
132
|
+
"GROUP_SIZE_M": 4,
|
133
133
|
"num_warps": 4,
|
134
134
|
"num_stages": 2,
|
135
135
|
"waves_per_eu": 0
|
136
136
|
},
|
137
137
|
"2048": {
|
138
|
-
"BLOCK_SIZE_M":
|
138
|
+
"BLOCK_SIZE_M": 32,
|
139
139
|
"BLOCK_SIZE_N": 128,
|
140
140
|
"BLOCK_SIZE_K": 128,
|
141
|
-
"GROUP_SIZE_M":
|
141
|
+
"GROUP_SIZE_M": 4,
|
142
142
|
"num_warps": 4,
|
143
143
|
"num_stages": 2,
|
144
144
|
"waves_per_eu": 0
|
@@ -156,7 +156,7 @@
|
|
156
156
|
"BLOCK_SIZE_M": 64,
|
157
157
|
"BLOCK_SIZE_N": 128,
|
158
158
|
"BLOCK_SIZE_K": 128,
|
159
|
-
"GROUP_SIZE_M":
|
159
|
+
"GROUP_SIZE_M": 4,
|
160
160
|
"num_warps": 4,
|
161
161
|
"num_stages": 2,
|
162
162
|
"waves_per_eu": 0
|
@@ -27,6 +27,10 @@ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
|
27
27
|
is_hip_ = is_hip()
|
28
28
|
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
29
29
|
|
30
|
+
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
31
|
+
if _is_cuda:
|
32
|
+
from sgl_kernel import sgl_per_token_group_quant_fp8
|
33
|
+
|
30
34
|
logger = logging.getLogger(__name__)
|
31
35
|
|
32
36
|
|
@@ -72,11 +76,60 @@ def _per_token_group_quant_fp8(
|
|
72
76
|
tl.store(y_s_ptr, y_s)
|
73
77
|
|
74
78
|
|
79
|
+
@triton.jit
|
80
|
+
def _per_token_group_quant_fp8_colmajor(
|
81
|
+
# Pointers to inputs and output
|
82
|
+
y_ptr,
|
83
|
+
y_q_ptr,
|
84
|
+
y_s_ptr,
|
85
|
+
group_size,
|
86
|
+
# Num columns of y
|
87
|
+
y_num_columns,
|
88
|
+
# Stride from one column to the next of y_s
|
89
|
+
y_s_col_stride,
|
90
|
+
# Avoid to divide zero
|
91
|
+
eps,
|
92
|
+
# Information for float8
|
93
|
+
fp8_min,
|
94
|
+
fp8_max,
|
95
|
+
# Meta-parameters
|
96
|
+
BLOCK: tl.constexpr,
|
97
|
+
):
|
98
|
+
"""A Triton-accelerated function to perform per-token-group
|
99
|
+
quantization on a tensor.
|
100
|
+
This function converts the tensor values into float8 values.
|
101
|
+
"""
|
102
|
+
# Map the program id to the row of X and Y it should compute.
|
103
|
+
g_id = tl.program_id(0)
|
104
|
+
y_ptr += g_id * group_size
|
105
|
+
y_q_ptr += g_id * group_size
|
106
|
+
|
107
|
+
# Convert g_id the flattened block coordinate to 2D so we can index
|
108
|
+
# into the output y_scales matrix
|
109
|
+
blocks_per_row = y_num_columns // group_size
|
110
|
+
scale_col = g_id % blocks_per_row
|
111
|
+
scale_row = g_id // blocks_per_row
|
112
|
+
y_s_ptr += scale_col * y_s_col_stride + scale_row
|
113
|
+
|
114
|
+
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
|
115
|
+
mask = cols < group_size
|
116
|
+
|
117
|
+
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
118
|
+
# Quant
|
119
|
+
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
120
|
+
y_s = _absmax / fp8_max
|
121
|
+
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
122
|
+
|
123
|
+
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
124
|
+
tl.store(y_s_ptr, y_s)
|
125
|
+
|
126
|
+
|
75
127
|
def per_token_group_quant_fp8(
|
76
128
|
x: torch.Tensor,
|
77
129
|
group_size: int,
|
78
130
|
eps: float = 1e-10,
|
79
131
|
dtype: torch.dtype = fp8_type_,
|
132
|
+
column_major_scales: bool = False,
|
80
133
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
81
134
|
"""Function to perform per-token-group quantization on an input tensor `x`.
|
82
135
|
|
@@ -108,30 +161,83 @@ def per_token_group_quant_fp8(
|
|
108
161
|
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
109
162
|
M = x.numel() // group_size
|
110
163
|
N = group_size
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
164
|
+
if column_major_scales:
|
165
|
+
x_s = torch.empty(
|
166
|
+
(x.shape[-1] // group_size,) + x.shape[:-1],
|
167
|
+
device=x.device,
|
168
|
+
dtype=torch.float32,
|
169
|
+
).permute(-1, -2)
|
170
|
+
else:
|
171
|
+
x_s = torch.empty(
|
172
|
+
x.shape[:-1] + (x.shape[-1] // group_size,),
|
173
|
+
device=x.device,
|
174
|
+
dtype=torch.float32,
|
175
|
+
)
|
116
176
|
|
117
177
|
BLOCK = triton.next_power_of_2(N)
|
118
178
|
# heuristics for number of warps
|
119
179
|
num_warps = min(max(BLOCK // 256, 1), 8)
|
120
180
|
num_stages = 1
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
181
|
+
if column_major_scales:
|
182
|
+
_per_token_group_quant_fp8_colmajor[(M,)](
|
183
|
+
x,
|
184
|
+
x_q,
|
185
|
+
x_s,
|
186
|
+
group_size,
|
187
|
+
x.shape[1],
|
188
|
+
x_s.stride(1),
|
189
|
+
eps,
|
190
|
+
fp8_min=fp8_min,
|
191
|
+
fp8_max=fp8_max,
|
192
|
+
BLOCK=BLOCK,
|
193
|
+
num_warps=num_warps,
|
194
|
+
num_stages=num_stages,
|
195
|
+
)
|
196
|
+
else:
|
197
|
+
_per_token_group_quant_fp8[(M,)](
|
198
|
+
x,
|
199
|
+
x_q,
|
200
|
+
x_s,
|
201
|
+
group_size,
|
202
|
+
N,
|
203
|
+
eps,
|
204
|
+
fp8_min=fp8_min,
|
205
|
+
fp8_max=fp8_max,
|
206
|
+
BLOCK=BLOCK,
|
207
|
+
num_warps=num_warps,
|
208
|
+
num_stages=num_stages,
|
209
|
+
)
|
210
|
+
|
211
|
+
return x_q, x_s
|
212
|
+
|
213
|
+
|
214
|
+
def sglang_per_token_group_quant_fp8(
|
215
|
+
x: torch.Tensor,
|
216
|
+
group_size: int,
|
217
|
+
eps: float = 1e-10,
|
218
|
+
dtype: torch.dtype = fp8_type_,
|
219
|
+
):
|
220
|
+
assert (
|
221
|
+
x.shape[-1] % group_size == 0
|
222
|
+
), "the last dimension of `x` cannot be divisible by `group_size`"
|
223
|
+
assert x.is_contiguous(), "`x` is not contiguous"
|
224
|
+
|
225
|
+
finfo = torch.finfo(dtype)
|
226
|
+
fp8_max = finfo.max
|
227
|
+
|
228
|
+
fp8_min = -fp8_max
|
229
|
+
|
230
|
+
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
231
|
+
M = x.numel() // group_size
|
232
|
+
N = group_size
|
233
|
+
x_s = torch.empty(
|
234
|
+
x.shape[:-1] + (x.shape[-1] // group_size,),
|
235
|
+
device=x.device,
|
236
|
+
dtype=torch.float32,
|
133
237
|
)
|
134
238
|
|
239
|
+
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
240
|
+
|
135
241
|
return x_q, x_s
|
136
242
|
|
137
243
|
|
@@ -10,6 +10,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|
10
10
|
from sglang.srt.utils import is_hip
|
11
11
|
|
12
12
|
is_hip_ = is_hip()
|
13
|
+
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
14
|
+
if _is_cuda:
|
15
|
+
from sgl_kernel import fp8_blockwise_scaled_mm
|
13
16
|
|
14
17
|
|
15
18
|
def normalize_e4m3fn_to_e4m3fnuz(
|
@@ -36,6 +39,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
|
|
36
39
|
return weight, weight_scale, input_scale
|
37
40
|
|
38
41
|
|
42
|
+
def cutlass_block_fp8_supported() -> bool:
|
43
|
+
if _is_cuda:
|
44
|
+
major, minor = torch.cuda.get_device_capability()
|
45
|
+
sm_version = major * 10 + minor
|
46
|
+
cuda_version = tuple(map(int, torch.version.cuda.split(".")))
|
47
|
+
if cuda_version >= (12, 0) and sm_version >= 90:
|
48
|
+
return True
|
49
|
+
return False
|
50
|
+
|
51
|
+
|
52
|
+
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
|
53
|
+
|
54
|
+
|
39
55
|
def apply_w8a8_block_fp8_linear(
|
40
56
|
input: torch.Tensor,
|
41
57
|
weight: torch.Tensor,
|
@@ -48,11 +64,24 @@ def apply_w8a8_block_fp8_linear(
|
|
48
64
|
# View input as 2D matrix for fp8 methods
|
49
65
|
input_2d = input.view(-1, input.shape[-1])
|
50
66
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
|
67
|
+
# TODO: add more robust shape check here
|
68
|
+
shape_supported_by_cutlass = (
|
69
|
+
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
|
55
70
|
)
|
71
|
+
if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
|
72
|
+
q_input, x_scale = per_token_group_quant_fp8(
|
73
|
+
input_2d, block_size[1], column_major_scales=True
|
74
|
+
)
|
75
|
+
output = fp8_blockwise_scaled_mm(
|
76
|
+
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
77
|
+
)
|
78
|
+
else:
|
79
|
+
q_input, x_scale = per_token_group_quant_fp8(
|
80
|
+
input_2d, block_size[1], column_major_scales=False
|
81
|
+
)
|
82
|
+
output = w8a8_block_fp8_matmul(
|
83
|
+
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
|
84
|
+
)
|
56
85
|
|
57
86
|
if bias is not None:
|
58
87
|
output = output + bias
|
@@ -210,6 +210,7 @@ class DetokenizerManager:
|
|
210
210
|
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
211
211
|
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
212
212
|
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
213
|
+
output_hidden_states=recv_obj.output_hidden_states,
|
213
214
|
)
|
214
215
|
)
|
215
216
|
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
|
|
371
371
|
output_top_logprobs_val: List[List]
|
372
372
|
output_top_logprobs_idx: List[List]
|
373
373
|
|
374
|
+
output_hidden_states: List[List[float]]
|
375
|
+
|
374
376
|
|
375
377
|
@dataclass
|
376
378
|
class BatchStrOut:
|
@@ -397,6 +399,8 @@ class BatchStrOut:
|
|
397
399
|
output_top_logprobs_val: List[List]
|
398
400
|
output_top_logprobs_idx: List[List]
|
399
401
|
|
402
|
+
output_hidden_states: List[List[float]]
|
403
|
+
|
400
404
|
|
401
405
|
@dataclass
|
402
406
|
class BatchEmbeddingOut:
|
@@ -65,6 +65,7 @@ global_server_args_dict = {
|
|
65
65
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
66
66
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
67
67
|
"device": ServerArgs.device,
|
68
|
+
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
68
69
|
}
|
69
70
|
|
70
71
|
logger = logging.getLogger(__name__)
|
@@ -315,6 +316,7 @@ class Req:
|
|
315
316
|
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
316
317
|
self.output_top_logprobs_val
|
317
318
|
) = self.output_top_logprobs_idx = None
|
319
|
+
self.hidden_states = []
|
318
320
|
|
319
321
|
# Logprobs (internal values)
|
320
322
|
# The tokens is prefilled but need to be considered as decode tokens
|
@@ -604,6 +606,9 @@ class ScheduleBatch:
|
|
604
606
|
# Enable custom logit processor
|
605
607
|
enable_custom_logit_processor: bool = False
|
606
608
|
|
609
|
+
# Return hidden states
|
610
|
+
return_hidden_states: bool = False
|
611
|
+
|
607
612
|
@classmethod
|
608
613
|
def init_new(
|
609
614
|
cls,
|
@@ -615,6 +620,7 @@ class ScheduleBatch:
|
|
615
620
|
enable_overlap: bool,
|
616
621
|
spec_algorithm: SpeculativeAlgorithm,
|
617
622
|
enable_custom_logit_processor: bool,
|
623
|
+
return_hidden_states: bool = False,
|
618
624
|
):
|
619
625
|
return cls(
|
620
626
|
reqs=reqs,
|
@@ -629,6 +635,7 @@ class ScheduleBatch:
|
|
629
635
|
device=req_to_token_pool.device,
|
630
636
|
spec_algorithm=spec_algorithm,
|
631
637
|
enable_custom_logit_processor=enable_custom_logit_processor,
|
638
|
+
return_hidden_states=return_hidden_states,
|
632
639
|
)
|
633
640
|
|
634
641
|
def batch_size(self):
|
@@ -1196,9 +1203,15 @@ class ScheduleBatch:
|
|
1196
1203
|
spec_algorithm=self.spec_algorithm,
|
1197
1204
|
spec_info=self.spec_info,
|
1198
1205
|
capture_hidden_mode=(
|
1199
|
-
|
1200
|
-
if self.
|
1201
|
-
else
|
1206
|
+
CaptureHiddenMode.FULL
|
1207
|
+
if self.return_hidden_states
|
1208
|
+
else (
|
1209
|
+
getattr(
|
1210
|
+
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
1211
|
+
)
|
1212
|
+
if self.spec_info
|
1213
|
+
else CaptureHiddenMode.NULL
|
1214
|
+
)
|
1202
1215
|
),
|
1203
1216
|
)
|
1204
1217
|
|
sglang/srt/managers/scheduler.py
CHANGED
@@ -997,6 +997,7 @@ class Scheduler:
|
|
997
997
|
self.enable_overlap,
|
998
998
|
self.spec_algorithm,
|
999
999
|
self.server_args.enable_custom_logit_processor,
|
1000
|
+
self.server_args.return_hidden_states,
|
1000
1001
|
)
|
1001
1002
|
new_batch.prepare_for_extend()
|
1002
1003
|
|
@@ -1156,6 +1157,8 @@ class Scheduler:
|
|
1156
1157
|
logits_output.input_token_logprobs.tolist()
|
1157
1158
|
)
|
1158
1159
|
|
1160
|
+
hidden_state_offset = 0
|
1161
|
+
|
1159
1162
|
# Check finish conditions
|
1160
1163
|
logprob_pt = 0
|
1161
1164
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
@@ -1182,6 +1185,21 @@ class Scheduler:
|
|
1182
1185
|
i, req, logprob_pt, next_token_ids, logits_output
|
1183
1186
|
)
|
1184
1187
|
|
1188
|
+
if (
|
1189
|
+
self.server_args.return_hidden_states
|
1190
|
+
and logits_output.hidden_states is not None
|
1191
|
+
):
|
1192
|
+
req.hidden_states.append(
|
1193
|
+
logits_output.hidden_states[
|
1194
|
+
hidden_state_offset : (
|
1195
|
+
hidden_state_offset := hidden_state_offset
|
1196
|
+
+ len(req.origin_input_ids)
|
1197
|
+
)
|
1198
|
+
]
|
1199
|
+
.cpu()
|
1200
|
+
.clone()
|
1201
|
+
)
|
1202
|
+
|
1185
1203
|
if req.grammar is not None:
|
1186
1204
|
req.grammar.accept_token(next_token_id)
|
1187
1205
|
req.grammar.finished = req.finished()
|
@@ -1275,6 +1293,12 @@ class Scheduler:
|
|
1275
1293
|
logits_output.next_token_top_logprobs_idx[i]
|
1276
1294
|
)
|
1277
1295
|
|
1296
|
+
if (
|
1297
|
+
self.server_args.return_hidden_states
|
1298
|
+
and logits_output.hidden_states is not None
|
1299
|
+
):
|
1300
|
+
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
1301
|
+
|
1278
1302
|
if req.grammar is not None:
|
1279
1303
|
req.grammar.accept_token(next_token_id)
|
1280
1304
|
req.grammar.finished = req.finished()
|
@@ -1398,6 +1422,7 @@ class Scheduler:
|
|
1398
1422
|
completion_tokens = []
|
1399
1423
|
cached_tokens = []
|
1400
1424
|
spec_verify_ct = []
|
1425
|
+
hidden_states = []
|
1401
1426
|
|
1402
1427
|
if return_logprob:
|
1403
1428
|
input_token_logprobs_val = []
|
@@ -1464,6 +1489,8 @@ class Scheduler:
|
|
1464
1489
|
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
1465
1490
|
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
1466
1491
|
|
1492
|
+
hidden_states.append(req.hidden_states)
|
1493
|
+
|
1467
1494
|
# Send to detokenizer
|
1468
1495
|
if rids:
|
1469
1496
|
self.send_to_detokenizer.send_pyobj(
|
@@ -1490,6 +1517,7 @@ class Scheduler:
|
|
1490
1517
|
input_top_logprobs_idx,
|
1491
1518
|
output_top_logprobs_val,
|
1492
1519
|
output_top_logprobs_idx,
|
1520
|
+
hidden_states,
|
1493
1521
|
)
|
1494
1522
|
)
|
1495
1523
|
else: # embedding or reward model
|
@@ -1553,6 +1581,7 @@ class Scheduler:
|
|
1553
1581
|
self.enable_overlap,
|
1554
1582
|
self.spec_algorithm,
|
1555
1583
|
self.server_args.enable_custom_logit_processor,
|
1584
|
+
self.server_args.return_hidden_states,
|
1556
1585
|
)
|
1557
1586
|
idle_batch.prepare_for_idle()
|
1558
1587
|
return idle_batch
|
@@ -796,6 +796,12 @@ class TokenizerManager:
|
|
796
796
|
}
|
797
797
|
)
|
798
798
|
|
799
|
+
if (
|
800
|
+
hasattr(recv_obj, "output_hidden_states")
|
801
|
+
and len(recv_obj.output_hidden_states[i]) > 0
|
802
|
+
):
|
803
|
+
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
804
|
+
|
799
805
|
if isinstance(recv_obj, BatchStrOut):
|
800
806
|
out_dict = {
|
801
807
|
"text": recv_obj.output_strs[i],
|
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
|
|
156
156
|
logits_output.input_token_logprobs = (
|
157
157
|
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
158
158
|
)
|
159
|
+
if logits_output.hidden_states is not None:
|
160
|
+
logits_output.hidden_states = logits_output.hidden_states.to(
|
161
|
+
"cpu", non_blocking=True
|
162
|
+
)
|
159
163
|
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
160
164
|
copy_done.record()
|
161
165
|
|
@@ -33,6 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
33
33
|
ForwardBatch,
|
34
34
|
ForwardMode,
|
35
35
|
)
|
36
|
+
from sglang.srt.utils import is_hip
|
37
|
+
|
38
|
+
is_hip_ = is_hip()
|
36
39
|
|
37
40
|
if TYPE_CHECKING:
|
38
41
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
@@ -129,6 +132,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|
129
132
|
if bs <= model_runner.req_to_token_pool.size
|
130
133
|
and bs <= server_args.cuda_graph_max_bs
|
131
134
|
]
|
135
|
+
if is_hip_:
|
136
|
+
capture_bs += [i * 8 for i in range(21, 33)]
|
132
137
|
compile_bs = (
|
133
138
|
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
|
134
139
|
if server_args.enable_torch_compile
|
@@ -349,7 +354,13 @@ class CudaGraphRunner:
|
|
349
354
|
spec_algorithm=self.model_runner.spec_algorithm,
|
350
355
|
spec_info=spec_info,
|
351
356
|
capture_hidden_mode=(
|
352
|
-
|
357
|
+
CaptureHiddenMode.FULL
|
358
|
+
if self.model_runner.server_args.return_hidden_states
|
359
|
+
else (
|
360
|
+
spec_info.capture_hidden_mode
|
361
|
+
if spec_info
|
362
|
+
else CaptureHiddenMode.NULL
|
363
|
+
)
|
353
364
|
),
|
354
365
|
)
|
355
366
|
|