sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +11 -2
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +205 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +292 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +13 -16
- sglang/srt/managers/tokenizer_manager.py +130 -111
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +23 -0
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +21 -37
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
ADDED
@@ -0,0 +1,292 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import List, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_group_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    # Stride of input
+    y_stride,
+    # Collums of input
+    N,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group quantization on a
+    tensor.
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    y_s_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+def per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform per-token-group quantization on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tenosr with ndim >= 2.
+        group_size: The group size used for quantization.
+        eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+    """
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    _per_token_group_quant_fp8[(M,)](
+        x,
+        x_q,
+        x_s,
+        group_size,
+        N,
+        eps,
+        fp8_min=fp8_min,
+        fp8_max=fp8_max,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return x_q, x_s
+
+
+@triton.jit
+def _w8a8_block_fp8_matmul(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def w8a8_block_fp8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    """This function performs matrix multiplication with block-wise quantization.
+
+    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+    The output is returned in the specified `output_dtype`.
+
+    Args:
+        A: The input tensor, e.g., activation.
+        B: The input tensor, e.g., weight.
+        As: The per-token-group quantization scale for `A`.
+        Bs: The per-block quantization scale for `B`.
+        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+        output_dytpe: The dtype of the returned tensor.
+
+    Returns:
+        torch.Tensor: The result of matmul.
+    """
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    # TODO(HandH1998):
+    # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
+    # BLOCK_SIZE_K must be divisable by block_k
+    # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements
+    BLOCK_SIZE_M = 128
+    if M < BLOCK_SIZE_M:
+        BLOCK_SIZE_M = triton.next_power_of_2(M)
+        BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
+    BLOCK_SIZE_K = block_k
+    assert block_k % BLOCK_SIZE_K == 0
+    BLOCK_SIZE_N = block_n
+
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        )
+
+    _w8a8_block_fp8_matmul[grid](
+        A,
+        B,
+        C,
+        As,
+        Bs,
+        M,
+        N,
+        K,
+        block_n,
+        block_k,
+        A.stride(-2),
+        A.stride(-1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(-2),
+        C.stride(-1),
+        As.stride(-2),
+        As.stride(-1),
+        Bs.stride(1),
+        Bs.stride(0),
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        GROUP_SIZE_M=8,
+    )
+
+    return C
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -1,6 +1,12 @@
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
+from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul,
+)
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
     if input_scale is not None:
         input_scale = input_scale * 2.0
     return weight, weight_scale, input_scale
+
+
+def apply_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+    output = w8a8_block_fp8_matmul(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+    )
+
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def input_to_float8(
+    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function quantizes input values to float8 values with tensor-wise quantization."""
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+def block_quant_to_tensor_quant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function converts block-wise quantization to tensor-wise quantization.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
+    Note only float8 is supported for now.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = x_q_block.to(torch.float32)
+
+    x_dq_block_tiles = [
+        [
+            x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            for i in range(k_tiles)
+        ]
+        for j in range(n_tiles)
+    ]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    return x_q_tensor, scale
+
+
+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -2,8 +2,14 @@
 Common utilities for torchao.
 """
 
+import logging
+import os
+import pwd
+
 import torch
 
+logger = logging.getLogger(__name__)
+
 
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
@@ -50,27 +56,17 @@ def apply_torchao_config_to_model(
     elif "gemlite" in torchao_config:
         # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
         # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
-        import
-        import
-
-        import gemlite
-        from gemlite.core import GemLiteLinearTriton, set_autotune
-
-        try:
-            from torchao.quantization import gemlite_uintx_weight_only
-        except:
-            print(
-                f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
-            )
-            return model
+        from gemlite.core import GemLiteLinearTriton
+        from torchao.quantization import gemlite_uintx_weight_only
 
         _quant_args = torchao_config.split("-")
         bit_width = int(_quant_args[-2])
         group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+
         try:
             packing_bitwidth = int(_quant_args[-3])
-        except:
-            # if only 2 inputs found, use default value
+        except (ValueError, IndexError):
+            # if only 2 inputs found or conversion fails, use default value
             packing_bitwidth = 32
 
         quantize_(
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -479,8 +479,22 @@ class Req:
 
         return True
 
+    def reset_for_retract(self):
+        self.prefix_indices = []
+        self.last_node = None
+        self.extend_input_len = 0
+        self.is_retracted = True
+
+        # For incremental logprobs
+        # TODO: Fix the `logprob_start_len`
+        self.last_update_decode_tokens = 0
+        self.logprob_start_len = 10**9
+
     def __repr__(self):
-        return
+        return (
+            f"rid(n={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+        )
 
 
 bid = 0
@@ -894,15 +908,7 @@ class ScheduleBatch:
             )
             residual_size = max(0, residual_size)
             self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
-
-            req.prefix_indices = []
-            req.last_node = None
-            req.extend_input_len = 0
-            req.is_retracted = True
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
+            req.reset_for_retract()
 
         self.filter_batch(keep_indices=sorted_indices)
 
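The retraction change is an extract-method refactor: the per-request field resets that ScheduleBatch.retract_decode used to do inline now live on Req.reset_for_retract(). A minimal standalone mirror of the reset semantics is sketched below; the class _ReqStateSketch is illustrative only and not part of sglang.

    from dataclasses import dataclass, field
    from typing import Any, List, Optional

    @dataclass
    class _ReqStateSketch:
        # Illustrative stand-in for the fields Req.reset_for_retract() touches.
        prefix_indices: List[int] = field(default_factory=list)
        last_node: Optional[Any] = None
        extend_input_len: int = 0
        is_retracted: bool = False
        last_update_decode_tokens: int = 0
        logprob_start_len: int = 0

        def reset_for_retract(self) -> None:
            self.prefix_indices = []
            self.last_node = None
            self.extend_input_len = 0
            self.is_retracted = True
            # Large sentinel: effectively disables incremental logprobs until rescheduled.
            self.last_update_decode_tokens = 0
            self.logprob_start_len = 10**9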
sglang/srt/managers/scheduler.py
CHANGED
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import List, Optional
+from typing import Callable, Dict, List, Optional, Tuple
 
 import psutil
 import setproctitle
@@ -260,7 +260,7 @@ class Scheduler:
         self.current_stream = torch.get_device_module(self.device).current_stream()
 
         # Session info
-        self.sessions = {}
+        self.sessions: Dict[str, Session] = {}
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -468,9 +468,6 @@ class Scheduler:
                 self.send_to_tokenizer.send_pyobj(
                     UpdateWeightFromDiskReqOutput(success, message)
                 )
-            elif isinstance(recv_req, GetWeightsByNameReqInput):
-                parameter = self.get_weights_by_name(recv_req)
-                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
             elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
                 success, message = self.init_weights_update_group(recv_req)
                 self.send_to_tokenizer.send_pyobj(
@@ -565,7 +562,7 @@ class Scheduler:
 
         if req.logprob_start_len == -1:
             # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(
+            req.logprob_start_len = len(req.origin_input_ids) - 1
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
@@ -589,12 +586,15 @@ class Scheduler:
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
+           or req.sampling_params.ebnf is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
+           elif req.sampling_params.ebnf is not None:
+               key = ("ebnf", req.sampling_params.ebnf)
 
            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
@@ -629,16 +629,13 @@ class Scheduler:
             self.waiting_queue.append(req)
 
     def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
-
-
-
-
-
-
-
-            )
-        else:
-            tree_cache_hit_rate = 0.0
+        self.tree_cache_metrics["total"] += (
+            adder.log_input_tokens + adder.log_hit_tokens
+        ) / 10**9
+        self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+        tree_cache_hit_rate = (
+            self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+        )
 
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()