sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +7 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +29 -15
- sglang/srt/layers/sampler.py +9 -6
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/scheduler.py +17 -3
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +19 -9
- sglang/srt/server_args.py +19 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -2
- sglang/test/runners.py +2 -0
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0

One of the new NVIDIA_B200 / fp8_w8a8 / block_shape=[128, 128] tuning configs listed above
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
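The keys in these tuning tables are batch-size buckets (M) and the values are Triton launch parameters (tile sizes, warp count, pipeline stages). A minimal sketch of how such a table could be consumed is shown below; `load_tuned_config` and `pick_config` are hypothetical helpers for illustration, not sglang APIs (the real lookups live in fused_moe.py and fp8_kernel.py, and selecting the nearest tuned M is an assumption about that logic).

```python
import json
from typing import Dict


def load_tuned_config(path: str) -> Dict[int, dict]:
    # Keys in these JSON files are batch-size buckets (M); values are Triton
    # launch parameters such as BLOCK_SIZE_M/N/K, num_warps and num_stages.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: Dict[int, dict], m: int) -> dict:
    # Pick the entry whose tuned batch size is closest to the runtime M.
    nearest = min(configs, key=lambda key: abs(key - m))
    return configs[nearest]


# Example with two entries from the table above:
cfgs = {1: {"BLOCK_SIZE_M": 32, "num_warps": 8, "num_stages": 3},
        64: {"BLOCK_SIZE_M": 64, "num_warps": 4, "num_stages": 3}}
print(pick_config(cfgs, 48))  # falls into the M=64 bucket
```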

sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale, requires_grad=False
                 )
                 layer.input_scale = None
+            else:
+                layer.weight = torch.nn.Parameter(
+                    layer.weight.data, requires_grad=False
+                )
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    layer.weight_scale_inv.data, requires_grad=False
+                )
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.

sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -22,7 +22,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name, is_hip
+from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 
 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@@ -220,6 +220,132 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@triton.jit
+def _w8a8_block_fp8_matmul_unrolledx4(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    # manually unroll to 4 iterations
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4):
+        # 1st iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        # 2nd iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k_start + BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        # 3rd iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k_start + BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        # 4th iteration
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k_start + BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
 @functools.lru_cache
 def get_w8a8_block_fp8_configs(
     N: int, K: int, block_n: int, block_k: int
@@ -324,7 +450,19 @@ def w8a8_block_fp8_matmul(
             triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         )
 
-    _w8a8_block_fp8_matmul[grid](
+    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+    # Empirical testing shows the sweet spot lies when it's less than the # of
+    # compute units available on the device.
+    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
+        N, config["BLOCK_SIZE_N"]
+    )
+    kernel = (
+        _w8a8_block_fp8_matmul_unrolledx4
+        if (is_hip_ == True and num_workgroups <= get_device_core_count())
+        else _w8a8_block_fp8_matmul
+    )
+
+    kernel[grid](
         A,
         B,
         C,
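The new dispatch above prefers the manually unrolled kernel on ROCm only when the launch grid fits within the device's compute units, so every workgroup can be resident at once. A standalone sketch of that heuristic follows; `choose_kernel` is a hypothetical illustration (not the sglang function), and the 304-CU device in the example is an assumed MI300X-class GPU.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, same as triton.cdiv.
    return -(-a // b)


def choose_kernel(m: int, n: int, block_m: int, block_n: int,
                  is_hip: bool, device_core_count: int) -> str:
    # Mirror of the heuristic in the diff: on AMD GPUs, pick the 4x-unrolled
    # kernel only when the launch grid is no larger than the compute-unit count.
    num_workgroups = cdiv(m, block_m) * cdiv(n, block_n)
    if is_hip and num_workgroups <= device_core_count:
        return "_w8a8_block_fp8_matmul_unrolledx4"
    return "_w8a8_block_fp8_matmul"


# Example: a 512 x 7168 output tiled 64 x 128 needs 8 * 56 = 448 workgroups,
# so on a 304-CU device the default kernel would still be chosen.
print(choose_kernel(512, 7168, 64, 128, is_hip=True, device_core_count=304))
```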

sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -6,9 +6,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from vllm
+from vllm import _custom_ops as ops
 
-from sglang.srt.
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.utils import is_cuda_available
+
+_is_cuda_available = is_cuda_available()
+if _is_cuda_available:
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -53,7 +58,6 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
@@ -75,7 +79,9 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
+        if not _is_cuda_available:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -141,17 +147,25 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        [11 removed lines truncated in the source diff view]
+        if _is_cuda_available:
+            apply_rope_with_cos_sin_cache_inplace(
+                positions=positions,
+                query=query,
+                key=key,
+                head_size=self.head_size,
+                cos_sin_cache=self.cos_sin_cache,
+                is_neox=self.is_neox_style,
+            )
+        else:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+            ops.rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
         return query, key
 
     def forward_xpu(
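For reference, the NeoX-style rotation that both the fused `sgl_kernel` path and the `ops.rotary_embedding` fallback apply can be written in a few lines of plain torch. This is only an illustrative reference implementation of the math, not the kernel used above.

```python
import torch


def rope_neox_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Reference NeoX-style rotary embedding over the last dim of x.

    x:   (..., head_dim)
    cos: (..., head_dim // 2), broadcastable against x's leading dims
    sin: (..., head_dim // 2)
    """
    x1, x2 = torch.chunk(x, 2, dim=-1)      # split halves (NeoX layout)
    rotated = torch.cat((-x2, x1), dim=-1)  # rotate each (x1, x2) pair by 90 degrees
    return x * torch.cat((cos, cos), dim=-1) + rotated * torch.cat((sin, sin), dim=-1)


# Tiny smoke test: a zero rotation angle leaves the input unchanged.
q = torch.randn(2, 8)
cos, sin = torch.ones(2, 4), torch.zeros(2, 4)
assert torch.allclose(rope_neox_reference(q, cos, sin), q)
```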

sglang/srt/layers/sampler.py
CHANGED
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
                 # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                 # https://github.com/flashinfer-ai/flashinfer/issues/708
                 # so we use the torch implementation.
+
+                # clamp to avoid -inf
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                )
+                ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
@@ -83,7 +85,7 @@
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids
+                batch_next_token_ids = min_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
@@ -95,9 +97,9 @@
                     filter_apply_order="joint",
                 )
 
-                [3 removed lines truncated in the source diff view]
+                if self.use_nan_detectioin and not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
 
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
@@ -109,9 +111,10 @@
                 sampling_info.need_min_p_sampling,
             )
             if return_logprob:
+                # clamp to avoid -inf
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                )
+                ).clamp(min=torch.finfo(probs.dtype).min)
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
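The `.clamp(min=torch.finfo(probs.dtype).min)` added in both branches guards against `log(0) = -inf` after top-p renormalization zeroes out filtered tokens. A small demonstration of the failure mode and the fix:

```python
import torch

probs = torch.tensor([0.7, 0.3, 0.0])  # the last token was zeroed out by top-p filtering
print(torch.log(probs))                # tensor([-0.3567, -1.2040,    -inf])

# Clamping to the dtype's most negative finite value keeps downstream logprob
# handling free of -inf while leaving the real values untouched.
safe = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
print(safe.isfinite().all())           # tensor(True)
```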

sglang/srt/lora/backend/base_backend.py
ADDED
@@ -0,0 +1,95 @@
+from typing import Tuple, Union
+
+import torch
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+    mapping = {
+        "triton": True,
+        "flashinfer": False,
+    }
+    return mapping.get(name, False)
+
+
+def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+    mapping = {
+        "triton": True,
+        "flashinfer": False,
+    }
+    return mapping.get(name, False)
+
+
+class BaseLoraBackend:
+    """Base class for different Lora backends.
+    Each backend has its own implementation of Lora kernels.
+
+    Args:
+        name: name of backend
+        batch_info: information of current batch for use
+        fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+            and the operation of scaling and adding will be fused into kernel
+    """
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        self.name = name
+        self.batch_info = batch_info
+        self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+        self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        """Run segment Gemm of lora a modules with current backend.
+        The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+                usually input_dim is much larger than r
+        Returns:
+            result with shape (s, r)
+        """
+        pass
+
+    def run_lora_b_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        """Run segment Gemm of lora b modules with current backend.
+        The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+        Args:
+            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+            weights: a set of lora weights with shape (num_lora, output_dim, r)
+                usually output_dim is much larger than r
+        Returns:
+            result with shape (s, output_dim)
+        """
+        pass
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        """Run the lora pass for QKV Layer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
+            qkv_lora_b: lora_b module for qkv.
+                If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
+                If passed in as a tuple of two tensors containing:
+                    a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
+                    and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
+        Returns:
+            result with shape (s, output_dim_q + 2 * output_dim_kv)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoraBatchInfo):
+        self.batch_info = batch_info
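Per the docstrings above, a backend only has to implement segment GEMMs over `(s, *)` inputs, applying one adapter per request segment selected through `batch_info`. The hypothetical torch-only backend below is a sketch of that contract, not one of the shipped backends; it assumes `LoraBatchInfo` exposes `bs`, `seg_indptr`, and `weight_indices`, the same fields the FlashInfer backend below reads.

```python
import torch

from sglang.srt.lora.backend import BaseLoraBackend


class NaiveTorchLoraBackend(BaseLoraBackend):
    """Hypothetical backend: plain torch loops per segment, no fused kernels."""

    def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # x: (s, input_dim), weights: (num_lora, r, input_dim) -> (s, r)
        out = x.new_zeros(x.shape[0], weights.shape[1])
        seg_indptr = self.batch_info.seg_indptr
        for i in range(self.batch_info.bs):
            start, end = seg_indptr[i], seg_indptr[i + 1]
            w = weights[self.batch_info.weight_indices[i]]  # (r, input_dim)
            out[start:end] = x[start:end] @ w.t()
        return out

    def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # x: (s, r), weights: (num_lora, output_dim, r) -> (s, output_dim)
        out = x.new_zeros(x.shape[0], weights.shape[1])
        seg_indptr = self.batch_info.seg_indptr
        for i in range(self.batch_info.bs):
            start, end = seg_indptr[i], seg_indptr[i + 1]
            w = weights[self.batch_info.weight_indices[i]]  # (output_dim, r)
            out[start:end] = x[start:end] @ w.t()
        return out
```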

sglang/srt/lora/backend/flashinfer_backend.py
ADDED
@@ -0,0 +1,91 @@
+from typing import Tuple
+
+import torch
+
+from sglang.srt.lora.backend import BaseLoraBackend
+from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    from flashinfer import SegmentGEMMWrapper
+
+
+class FlashInferLoraBackend(BaseLoraBackend):
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        super().__init__(name, batch_info)
+
+        # Set up SGemm Wrapper from flashinfer
+        # FIXME wait for flashinfer segment gemm update
+        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
+        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+
+        return self.segment_gemm.run(
+            x=x,
+            weights=weights,
+            batch_size=self.batch_info.bs,
+            weight_column_major=True,
+            seg_indptr=self.batch_info.seg_indptr,
+            weight_indices=self.batch_info.weight_indices,
+        )
+
+    def run_lora_b_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+
+        return self.segment_gemm.run(
+            x=x,
+            weights=weights,
+            batch_size=self.batch_info.bs,
+            weight_column_major=True,
+            seg_indptr=self.batch_info.seg_indptr,
+            weight_indices=self.batch_info.weight_indices,
+        )
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        # Shape of lora_a_output: (s, 3 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
+
+        q_lora_b, kv_lora_b = qkv_lora_b
+        lora_rank = kv_lora_b.shape[-1]
+        output_dim_q = q_lora_b.shape[-2]
+        output_dim_kv = kv_lora_b.shape[-2]
+        lora_output = torch.empty(
+            (x.shape[0], output_dim_q + 2 * output_dim_kv),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # q
+        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
+        )
+
+        # kv
+        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
+            self.run_lora_b_sgemm(
+                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
+                weights=kv_lora_b[0],
+            )
+        )
+
+        lora_output[
+            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
+        ] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
+            weights=kv_lora_b[1],
+        )
+
+        return lora_output
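`run_qkv_lora` packs the q, k, and v LoRA outputs side by side in one buffer. A quick shape walk-through with illustrative numbers (a 32-query-head / 8-KV-head projection with head_dim 128 and rank r = 16; these values are examples, not taken from the diff):

```python
# Packed output layout of run_qkv_lora for the example projection sizes above.
output_dim_q = 32 * 128   # 4096
output_dim_kv = 8 * 128   # 1024
r = 16

# lora_a_output: (s, 3 * r) -> columns [0:16] feed q, [16:32] feed k, [32:48] feed v
# lora_output:   (s, output_dim_q + 2 * output_dim_kv) = (s, 6144)
#   [:, 0:4096]    <- q slice, from q_lora_b[0]  with shape (num_lora, 4096, 16)
#   [:, 4096:5120] <- k slice, from kv_lora_b[0] with shape (num_lora, 1024, 16)
#   [:, 5120:6144] <- v slice, from kv_lora_b[1] with shape (num_lora, 1024, 16)
print(output_dim_q + 2 * output_dim_kv)  # 6144
```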