sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +124 -12
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +173 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/llama.py +8 -3
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +486 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +420 -401
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/qkv_lora_b.py
@@ -0,0 +1,182 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _qkv_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Parameters of size
+    K,  # K = R
+    max_qkv_out_dim,  # max(output_q_dim, output_kv_dim)
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Offsets of q/k/v slice on output dimension
+    n_offs,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # This kernel packs 3 sgemms (q/k/v) into a single kernel.
+
+    # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
+    # weights: (num_lora, N_Q + 2 * N_KV, K)
+    # output: (s, N_Q + 2 * N_KV)
+    # N_Q >> K, N_KV >> K
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len.
+    # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
+    batch_id = tl.program_id(axis=2)
+    qkv_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+    n_start = tl.load(n_offs + qkv_id)
+    n_size = tl.load(n_offs + qkv_id + 1) - n_start
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+
+    x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def qkv_lora_b_fwd(
+    x: torch.Tensor,
+    qkv_lora_b: torch.Tensor,
+    batch_info: LoraBatchInfo,
+    output_offset: torch.Tensor,
+    max_qkv_out_dim: int,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+
+    # x: (s, 3 * r)
+    # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+    # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv,
+    #                  output_dim_q + 2 * output_dim_kv]
+    # max_qkv_out_dim = max(output_dim_q, output_dim_kv)
+    # output: (s, output_dim_q + 2 * output_dim_kv)
+
+    # Compute lora_output with shape (s, output_dim) as follows:
+    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b)
+    # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
+    #     = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+    # lora_output[:, output_dim_q + output_dim_kv: ]
+    #     = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+
+    # Get dims
+    s = x.shape[0]
+    input_dim = x.shape[1]
+    r = qkv_lora_b.shape[-1]
+    output_dim = qkv_lora_b.shape[-2]
+    assert input_dim == 3 * r
+    assert output_offset.shape[0] == 4
+
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_OUT = 64
+
+    grid_b = (
+        triton.cdiv(batch_info.max_len, BLOCK_S)
+        * triton.cdiv(max_qkv_out_dim, BLOCK_OUT),
+        3,  # this dimension decides current block computes on q, k or v
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _qkv_lora_b_kernel[grid_b](
+        x,
+        qkv_lora_b,
+        output,
+        r,
+        max_qkv_out_dim,
+        x.stride(0),
+        x.stride(1),
+        qkv_lora_b.stride(0),
+        qkv_lora_b.stride(1),
+        qkv_lora_b.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        output_offset,
+        BLOCK_S,
+        BLOCK_OUT,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+
+    return output
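For orientation, here is a minimal sketch of driving `qkv_lora_b_fwd` by hand. The shapes are hypothetical, and the keyword construction of `LoraBatchInfo` is an assumption inferred from the fields the kernels read (`bs`, `seg_lens`, `seg_indptr`, `max_len`, `weight_indices`), not taken from the package itself:

```python
import torch
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops.qkv_lora_b import qkv_lora_b_fwd

# Hypothetical sizes: rank r=16, q output dim 128, k/v output dim 64 each.
r, dq, dkv = 16, 128, 64

# Two sequences of lengths 3 and 5, both using adapter 0.
# Assumed LoraBatchInfo construction (fields inferred from kernel usage).
batch_info = LoraBatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.zeros(2, dtype=torch.int32, device="cuda"),
)

x = torch.randn(8, 3 * r, device="cuda", dtype=torch.float16)  # stacked LoRA-A outputs
qkv_lora_b = torch.randn(1, dq + 2 * dkv, r, device="cuda", dtype=torch.float16)
output_offset = torch.tensor(
    [0, dq, dq + dkv, dq + 2 * dkv], dtype=torch.int32, device="cuda"
)

out = qkv_lora_b_fwd(
    x, qkv_lora_b, batch_info, output_offset, max_qkv_out_dim=max(dq, dkv), scaling=0.5
)
print(out.shape)  # torch.Size([8, 256])
```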
sglang/srt/lora/triton_ops/sgemm_lora_a.py
@@ -0,0 +1,143 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _sgemm_lora_a_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Matrix dimensions
+    N,  # r
+    K,  # input_dim
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+
+    # x: (s, K), s is the sum of sequence lengths
+    # weights: (num_lora, N, K)
+    # output: (s, N)
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len
+    batch_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+    x_ptrs = (x + seg_start * x_stride_0) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def sgemm_lora_a_fwd(
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+) -> torch.Tensor:
+    # x: (s, input_dim)
+    # weights: (num_lora, r, input_dim)
+    # output: (s, r)
+    # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
+    # input_dim is much larger than r
+
+    assert x.is_contiguous()
+    assert weights.is_contiguous()
+    assert len(x.shape) == 2
+    assert len(weights.shape) == 3
+
+    S = x.shape[0]
+    R = weights.shape[-2]
+    K = weights.shape[-1]
+    assert x.shape[-1] == K
+
+    # Block shapes
+    BLOCK_S = 16
+    BLOCK_K = 256
+    BLOCK_R = 16
+
+    grid = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),
+        batch_info.bs,
+    )
+
+    output = torch.empty((S, R), device=x.device, dtype=x.dtype)
+    _sgemm_lora_a_kernel[grid](
+        x,
+        weights,
+        output,
+        R,
+        K,
+        x.stride(0),
+        x.stride(1),
+        weights.stride(0),
+        weights.stride(1),
+        weights.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_R,
+        BLOCK_K,
+    )
+    return output
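A quick way to see what `sgemm_lora_a_fwd` computes: for each segment, it is an ordinary GEMM against that segment's adapter, `output[seg] = x[seg] @ weights[idx].T`. A minimal sketch with hypothetical shapes, using the same assumed `LoraBatchInfo` construction as the previous sketch:

```python
import torch
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops.sgemm_lora_a import sgemm_lora_a_fwd

r, input_dim = 16, 4096
x = torch.randn(8, input_dim, device="cuda", dtype=torch.float16)
weights = torch.randn(2, r, input_dim, device="cuda", dtype=torch.float16)

# Assumed LoraBatchInfo construction (fields inferred from kernel usage):
# sequence 0 covers rows 0:3 and uses adapter 0; sequence 1 covers rows 3:8
# and uses adapter 1.
batch_info = LoraBatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.tensor([0, 1], dtype=torch.int32, device="cuda"),
)

out = sgemm_lora_a_fwd(x, weights, batch_info)  # (8, r)

# Per segment, the kernel is a plain GEMM with that segment's A matrix.
ref = torch.cat([x[0:3] @ weights[0].T, x[3:8] @ weights[1].T])
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
```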
sglang/srt/lora/triton_ops/sgemm_lora_b.py
@@ -0,0 +1,159 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _sgemm_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Matrix dimensions
+    N,  # output_dim
+    K,  # r
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # x: (s, K), s is the sum of sequence lengths
+    # weights: (num_lora, N, K)
+    # output: (s, N)
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len
+    batch_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+    x_ptrs = (x + seg_start * x_stride_0) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = s_offset[:, None] < seg_len
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def sgemm_lora_b_fwd(
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoraBatchInfo,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+    # x: (s, r)
+    # weights: (num_lora, output_dim, r)
+    # output: (s, output_dim)
+    # output_dim is much larger than r
+
+    assert x.is_contiguous()
+    assert weights.is_contiguous()
+    assert len(x.shape) == 2
+    assert len(weights.shape) == 3
+
+    S = x.shape[0]
+    N = weights.shape[-2]
+    R = weights.shape[-1]
+    assert x.shape[-1] == R
+
+    # Block shapes
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_N = 256
+
+    grid = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N),
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((S, N), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _sgemm_lora_b_kernel[grid](
+        x,
+        weights,
+        output,
+        N,
+        R,
+        x.stride(0),
+        x.stride(1),
+        weights.stride(0),
+        weights.stride(1),
+        weights.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_N,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+    return output
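Chaining the shrink and expand kernels gives the usual LoRA update `y = base + scaling * (x @ A.T) @ B.T`, with the add fused into the B-side kernel via `base_output`. A minimal end-to-end sketch under the same assumptions as the earlier sketches (hypothetical shapes, assumed `LoraBatchInfo` construction):

```python
import torch
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops.sgemm_lora_a import sgemm_lora_a_fwd
from sglang.srt.lora.triton_ops.sgemm_lora_b import sgemm_lora_b_fwd

# Hypothetical sizes: 2 adapters of rank 16 over a 4096-wide projection.
lora_a = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
lora_b = torch.randn(2, 4096, 16, device="cuda", dtype=torch.float16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
base = torch.randn(8, 4096, device="cuda", dtype=torch.float16)  # base-layer output

# Assumed LoraBatchInfo construction (fields inferred from kernel usage).
batch_info = LoraBatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.tensor([0, 1], dtype=torch.int32, device="cuda"),
)

a_out = sgemm_lora_a_fwd(x, lora_a, batch_info)  # shrink: (8, 16)
# Passing base_output makes the kernel accumulate into `base` in place,
# so `y` and `base` are the same tensor after the call.
y = sgemm_lora_b_fwd(a_out, lora_b, batch_info, base_output=base, scaling=0.5)
```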