sglang 0.4.0.post2__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +1 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +110 -98
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/scheduler.py +2 -2
- sglang/srt/managers/tokenizer_manager.py +86 -76
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -0
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/server.py +1 -0
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/METADATA +3 -3
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/RECORD +44 -40
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
NEW
@@ -0,0 +1,278 @@
+from typing import List, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_group_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    # Stride of input
+    y_stride,
+    # Collums of input
+    N,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group quantization on a
+    tensor.
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    y_s_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+def per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform per-token-group quantization on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tenosr with ndim >= 2.
+        group_size: The group size used for quantization.
+        eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+    """
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    _per_token_group_quant_fp8[(M,)](
+        x,
+        x_q,
+        x_s,
+        group_size,
+        N,
+        eps,
+        fp8_min=fp8_min,
+        fp8_max=fp8_max,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return x_q, x_s
+
+
+@triton.jit
+def _w8a8_block_fp8_matmul(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def w8a8_block_fp8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    """This function performs matrix multiplication with block-wise quantization.
+
+    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+    The output is returned in the specified `output_dtype`.
+
+    Args:
+        A: The input tensor, e.g., activation.
+        B: The input tensor, e.g., weight.
+        As: The per-token-group quantization scale for `A`.
+        Bs: The per-block quantization scale for `B`.
+        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+        output_dytpe: The dtype of the returned tensor.
+
+    Returns:
+        torch.Tensor: The result of matmul.
+    """
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    # TODO(HandH1998):
+    # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
+    # BLOCK_SIZE_K must be divisable by block_k
+    # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements
+    BLOCK_SIZE_M = 128
+    if M < BLOCK_SIZE_M:
+        BLOCK_SIZE_M = triton.next_power_of_2(M)
+        BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
+    BLOCK_SIZE_K = block_k
+    assert block_k % BLOCK_SIZE_K == 0
+    BLOCK_SIZE_N = block_n
+
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        )
+
+    _w8a8_block_fp8_matmul[grid](
+        A,
+        B,
+        C,
+        As,
+        Bs,
+        M,
+        N,
+        K,
+        block_n,
+        block_k,
+        A.stride(-2),
+        A.stride(-1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(-2),
+        C.stride(-1),
+        As.stride(-2),
+        As.stride(-1),
+        Bs.stride(1),
+        Bs.stride(0),
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        GROUP_SIZE_M=8,
+    )
+
+    return C
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -1,6 +1,12 @@
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
+from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul,
+)
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
     if input_scale is not None:
         input_scale = input_scale * 2.0
     return weight, weight_scale, input_scale
+
+
+def apply_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+    output = w8a8_block_fp8_matmul(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+    )
+
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def input_to_float8(
+    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function quantizes input values to float8 values with tensor-wise quantization."""
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+def block_quant_to_tensor_quant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function converts block-wise quantization to tensor-wise quantization.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
+    Note only float8 is supported for now.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = x_q_block.to(torch.float32)
+
+    x_dq_block_tiles = [
+        [
+            x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            for i in range(k_tiles)
+        ]
+        for j in range(n_tiles)
+    ]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    return x_q_tensor, scale
+
+
+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -2,8 +2,14 @@
 Common utilities for torchao.
 """
 
+import logging
+import os
+import pwd
+
 import torch
 
+logger = logging.getLogger(__name__)
+
 
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
@@ -50,27 +56,17 @@ def apply_torchao_config_to_model(
     elif "gemlite" in torchao_config:
         # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
         # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
-        import os
-        import pwd
-
-        import gemlite
-        from gemlite.core import GemLiteLinearTriton, set_autotune
-
-        try:
-            from torchao.quantization import gemlite_uintx_weight_only
-        except:
-            print(
-                f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
-            )
-            return model
+        from gemlite.core import GemLiteLinearTriton
+        from torchao.quantization import gemlite_uintx_weight_only
 
         _quant_args = torchao_config.split("-")
         bit_width = int(_quant_args[-2])
         group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+
         try:
             packing_bitwidth = int(_quant_args[-3])
-        except:
-            # if only 2 inputs found, use default value
+        except (ValueError, IndexError):
+            # if only 2 inputs found or conversion fails, use default value
            packing_bitwidth = 32
 
         quantize_(
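The simplified gemlite branch parses its settings straight from the config string. A standalone sketch of that parsing logic (mirroring the code above; the example strings are illustrative):

```python
def parse_gemlite_config(torchao_config: str):
    # "gemlite-<packing_bitwidth>-<bit_width>-<group_size>" or
    # "gemlite-<bit_width>-<group_size>" (packing_bitwidth defaults to 32)
    _quant_args = torchao_config.split("-")
    bit_width = int(_quant_args[-2])
    group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
    try:
        packing_bitwidth = int(_quant_args[-3])
    except (ValueError, IndexError):
        # only two numeric fields present (or not an int): use the default
        packing_bitwidth = 32
    return packing_bitwidth, bit_width, group_size


print(parse_gemlite_config("gemlite-8-4-128"))  # (8, 4, 128)
print(parse_gemlite_config("gemlite-4-128"))    # (32, 4, 128)
print(parse_gemlite_config("gemlite-4-None"))   # (32, 4, None)
```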
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -479,8 +479,22 @@ class Req:
 
         return True
 
+    def reset_for_retract(self):
+        self.prefix_indices = []
+        self.last_node = None
+        self.extend_input_len = 0
+        self.is_retracted = True
+
+        # For incremental logprobs
+        # TODO: Fix the `logprob_start_len`
+        self.last_update_decode_tokens = 0
+        self.logprob_start_len = 10**9
+
     def __repr__(self):
-        return
+        return (
+            f"rid(n={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+        )
 
 
 bid = 0
@@ -894,15 +908,7 @@ class ScheduleBatch:
             )
             residual_size = max(0, residual_size)
             self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
-
-            req.prefix_indices = []
-            req.last_node = None
-            req.extend_input_len = 0
-            req.is_retracted = True
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
+            req.reset_for_retract()
 
         self.filter_batch(keep_indices=sorted_indices)
 
sglang/srt/managers/scheduler.py
CHANGED
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import List, Optional
+from typing import Callable, Dict, List, Optional, Tuple
 
 import psutil
 import setproctitle
@@ -260,7 +260,7 @@ class Scheduler:
         self.current_stream = torch.get_device_module(self.device).current_stream()
 
         # Session info
-        self.sessions = {}
+        self.sessions: Dict[str, Session] = {}
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
|