sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
ADDED
@@ -0,0 +1,278 @@
+from typing import List, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_group_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    # Stride of input
+    y_stride,
+    # Collums of input
+    N,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group quantization on a
+    tensor.
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    y_s_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+def per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform per-token-group quantization on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tenosr with ndim >= 2.
+        group_size: The group size used for quantization.
+        eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+    """
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    _per_token_group_quant_fp8[(M,)](
+        x,
+        x_q,
+        x_s,
+        group_size,
+        N,
+        eps,
+        fp8_min=fp8_min,
+        fp8_max=fp8_max,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return x_q, x_s
+
+
+@triton.jit
+def _w8a8_block_fp8_matmul(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def w8a8_block_fp8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    """This function performs matrix multiplication with block-wise quantization.
+
+    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+    The output is returned in the specified `output_dtype`.
+
+    Args:
+        A: The input tensor, e.g., activation.
+        B: The input tensor, e.g., weight.
+        As: The per-token-group quantization scale for `A`.
+        Bs: The per-block quantization scale for `B`.
+        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+        output_dytpe: The dtype of the returned tensor.
+
+    Returns:
+        torch.Tensor: The result of matmul.
+    """
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    # TODO(HandH1998):
+    # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
+    # BLOCK_SIZE_K must be divisable by block_k
+    # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements
+    BLOCK_SIZE_M = 128
+    if M < BLOCK_SIZE_M:
+        BLOCK_SIZE_M = triton.next_power_of_2(M)
+        BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
+    BLOCK_SIZE_K = block_k
+    assert block_k % BLOCK_SIZE_K == 0
+    BLOCK_SIZE_N = block_n
+
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        )
+
+    _w8a8_block_fp8_matmul[grid](
+        A,
+        B,
+        C,
+        As,
+        Bs,
+        M,
+        N,
+        K,
+        block_n,
+        block_k,
+        A.stride(-2),
+        A.stride(-1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(-2),
+        C.stride(-1),
+        As.stride(-2),
+        As.stride(-1),
+        Bs.stride(1),
+        Bs.stride(0),
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        GROUP_SIZE_M=8,
+    )
+
+    return C
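The new kernel file above is the core of 0.4.1's block-wise FP8 path. As a reading aid, the following is a minimal pure-PyTorch sketch of the same per-token-group quantization math: one float32 scale per group of group_size values, scale = absmax / fp8_max, then clamp and cast. The helper name ref_per_token_group_quant is ours and the snippet is illustrative only; it assumes a PyTorch build with torch.float8_e4m3fn (2.1+), while the shipped implementation is the Triton kernel shown in the diff.

# Hypothetical reference for per-token-group FP8 quantization, mirroring the
# math in _per_token_group_quant_fp8 above (not part of sglang).
import torch

def ref_per_token_group_quant(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Split the last dimension into groups of `group_size` values.
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / finfo.max                      # one float32 scale per group
    q = (groups / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)
x_q, x_s = ref_per_token_group_quant(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])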
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -1,6 +1,12 @@
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
+from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul,
+)
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
     if input_scale is not None:
         input_scale = input_scale * 2.0
     return weight, weight_scale, input_scale
+
+
+def apply_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+    output = w8a8_block_fp8_matmul(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+    )
+
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def input_to_float8(
+    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function quantizes input values to float8 values with tensor-wise quantization."""
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+def block_quant_to_tensor_quant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function converts block-wise quantization to tensor-wise quantization.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
+    Note only float8 is supported for now.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = x_q_block.to(torch.float32)
+
+    x_dq_block_tiles = [
+        [
+            x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            for i in range(k_tiles)
+        ]
+        for j in range(n_tiles)
+    ]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    return x_q_tensor, scale
+
+
+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
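The asserts in w8a8_block_fp8_matmul and the call in apply_w8a8_block_fp8_linear imply a specific scale layout: activations carry one scale per (token, K-group) and weights carry one scale per (N-block, K-block) tile. Below is a shape-only sketch of that contract, assuming a [128, 128] block size; it only builds empty placeholder tensors and does not launch the Triton kernel.

# Shape-only sketch of the w8a8 block-FP8 contract implied by the asserts above
# (illustrative; runs on CPU and does not call the kernel).
import torch

M, K, N = 4, 512, 1024             # tokens, in-features, out-features
block_n, block_k = 128, 128        # block_size = [block_n, block_k]

A = torch.empty(M, K, dtype=torch.float8_e4m3fn)               # per-token-group quantized activation
As = torch.empty(M, K // block_k, dtype=torch.float32)         # one scale per (token, K-group)
B = torch.empty(N, K, dtype=torch.float8_e4m3fn)               # block-quantized weight
Bs = torch.empty(N // block_n, K // block_k, dtype=torch.float32)  # one scale per (N, K) block

assert A.shape[-1] == B.shape[-1]
assert As.shape == (M, (K + block_k - 1) // block_k)
assert Bs.shape == ((N + block_n - 1) // block_n, (K + block_k - 1) // block_k)
# On a GPU with FP8 support, w8a8_block_fp8_matmul(A, B, As, Bs, [block_n, block_k])
# would return an [M, N] tensor in the requested output dtype.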
sglang/srt/layers/radix_attention.py
CHANGED
@@ -48,7 +48,14 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
 
-    def forward(self, q, k, v, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
sglang/srt/layers/sampler.py
CHANGED
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                 sampling_info.top_ks,
                 sampling_info.top_ps,
                 sampling_info.min_ps,
+                sampling_info.need_min_p_sampling,
             )
         else:
             raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -2,8 +2,14 @@
 Common utilities for torchao.
 """
 
+import logging
+import os
+import pwd
+
 import torch
 
+logger = logging.getLogger(__name__)
+
 
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
@@ -47,6 +53,31 @@ def apply_torchao_config_to_model(
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+    elif "gemlite" in torchao_config:
+        # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+        # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+        from gemlite.core import GemLiteLinearTriton
+        from torchao.quantization import gemlite_uintx_weight_only
+
+        _quant_args = torchao_config.split("-")
+        bit_width = int(_quant_args[-2])
+        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+
+        try:
+            packing_bitwidth = int(_quant_args[-3])
+        except (ValueError, IndexError):
+            # if only 2 inputs found or conversion fails, use default value
+            packing_bitwidth = 32
+
+        quantize_(
+            model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+        )
+
+        # try to load gemlite kernel config
+        GemLiteLinearTriton.load_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -17,9 +17,10 @@ import dataclasses
 import logging
 import signal
 from collections import OrderedDict
-from typing import List, Union
+from typing import Dict, List, Union
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +29,6 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,17 +75,25 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
-    def
-
+    def trim_matched_stop(
+        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+    ):
+        if no_stop_trim or not finished_reason:
+            return output
+
+        matched = finished_reason.get("matched", None)
+        if not matched:
             return output
 
-        #
-
-
+        # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+        # Trim stop str.
+        if isinstance(matched, str) and isinstance(output, str):
+            pos = output.find(matched)
             return output[:pos] if pos != -1 else output
-
-
-        ):
+
+        # Trim stop token.
+        if isinstance(matched, int) and isinstance(output, list):
             assert len(output) > 0
             return output[:-1]
         return output
@@ -124,9 +132,9 @@ class DetokenizerManager:
                 s.decode_ids = recv_obj.decode_ids[i]
 
                 read_ids.append(
-                    self.
+                    self.trim_matched_stop(
                         s.decode_ids[s.surr_offset :],
-                        recv_obj.
+                        recv_obj.finished_reasons[i],
                         recv_obj.no_stop_trim[i],
                     )
                 )
@@ -149,7 +157,7 @@ class DetokenizerManager:
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
-                if recv_obj.
+                if recv_obj.finished_reasons[i] is None:
                     # Streaming chunk: update the decode status
                     if len(new_text) > 0 and not new_text.endswith("�"):
                         s.decoded_text = s.decoded_text + new_text
@@ -160,9 +168,9 @@ class DetokenizerManager:
                     new_text = find_printable_text(new_text)
 
                 output_strs.append(
-                    self.
+                    self.trim_matched_stop(
                         s.decoded_text + new_text,
-                        recv_obj.
+                        recv_obj.finished_reasons[i],
                         recv_obj.no_stop_trim[i],
                     )
                 )
@@ -170,9 +178,20 @@ class DetokenizerManager:
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
                     rids=recv_obj.rids,
+                    finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
-
-
+                    prompt_tokens=recv_obj.prompt_tokens,
+                    completion_tokens=recv_obj.completion_tokens,
+                    cached_tokens=recv_obj.cached_tokens,
+                    input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                    input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                    output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                    output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                    input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                    input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                    output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                    output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                    normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
                 )
             )
 
@@ -194,6 +213,7 @@ def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
 ):
+    setproctitle.setproctitle("sglang::detokenizer")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
 
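The new trim_matched_stop helper works on a plain finish-reason dict: it cuts the decoded text at the first occurrence of the matched stop string, or drops the final token id when a stop token matched. Below is a standalone sketch of that behavior (illustrative; in sglang it is a DetokenizerManager method and the dict comes from the scheduler).

# Standalone sketch of the trim_matched_stop logic shown above (not the real class).
from typing import Dict, List, Union

def trim_matched_stop(output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool):
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    # Stop string: cut the text at the first occurrence of the matched string.
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    # Stop token: drop the final token id.
    if isinstance(matched, int) and isinstance(output, list):
        return output[:-1]
    return output

print(trim_matched_stop("Hello world###", {"matched": "###"}, no_stop_trim=False))  # Hello world
print(trim_matched_stop([11, 22, 2], {"matched": 2}, no_stop_trim=False))           # [11, 22]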
sglang/srt/managers/io_struct.py
CHANGED
@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
+    # For incremental decoding
     # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
@@ -315,35 +318,61 @@
     read_offsets: List[int]
     # Only used when `--skip-tokenizer-init`
     output_ids: Optional[List[int]]
+    # Detokenization configs
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
-    meta_info: List[Dict]
-    finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
-
-
-
-
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # The output embedding
     embeddings: List[List[float]]
-    #
-
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]
 
 
 @dataclass
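Across BatchTokenIDOut, BatchStrOut, and BatchEmbeddingOut, the per-request meta_info dicts are replaced by index-aligned parallel lists (finish reasons, token counts, logprob values and indices). The following trimmed-down stand-in sketches that layout; MiniBatchStrOut and its field subset are ours, for illustration only.

# Minimal stand-in illustrating the parallel-list layout of the reworked
# BatchStrOut (only a subset of the real fields is shown).
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniBatchStrOut:
    rids: List[str]                        # one entry per request
    finished_reasons: List[Optional[dict]]
    output_strs: List[str]
    prompt_tokens: List[int]
    completion_tokens: List[int]

out = MiniBatchStrOut(
    rids=["req-0", "req-1"],
    finished_reasons=[{"type": "length"}, None],  # None => request is still streaming
    output_strs=["Hello", "Bonjour"],
    prompt_tokens=[12, 7],
    completion_tokens=[16, 3],
)
# All lists are index-aligned: everything at position i belongs to request rids[i].
print(out.rids[1], out.output_strs[1], out.finished_reasons[1])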