sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +7 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +29 -15
- sglang/srt/layers/sampler.py +9 -6
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/scheduler.py +17 -3
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +19 -9
- sglang/srt/server_args.py +19 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -2
- sglang/test/runners.py +2 -0
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/qkv_lora_b.py
@@ -0,0 +1,182 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _qkv_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Parameters of size
+    K,  # K = R
+    max_qkv_out_dim,  # max(output_q_dim, output_kv_dim)
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Offsets of q/k/v slice on output dimension
+    n_offs,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # This kernel packs 3 sgemms (q/k/v) into a single kernel.
+
+    # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
+    # weights: (num_lora, N_Q + 2 * N_KV, K)
+    # output: (s, N_Q + 2 * N_KV)
+    # N_Q >> K, N_KV >> K
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len.
+    # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
+    batch_id = tl.program_id(axis=2)
+    qkv_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+    n_start = tl.load(n_offs + qkv_id)
+    n_size = tl.load(n_offs + qkv_id + 1) - n_start
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+
+    x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iteate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def qkv_lora_b_fwd(
+    x: torch.Tensor,
+    qkv_lora_b: torch.Tensor,
+    batch_info: LoraBatchInfo,
+    output_offset: torch.Tensor,
+    max_qkv_out_dim: int,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+
+    # x: (s, 3 * r)
+    # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+    # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv,
+    #                  output_dim_q + 2 * output_dim_kv]
+    # max_qkv_out_dim = max(output_dim_q, output_dim_kv)
+    # output: (s, output_dim_q + 2 * output_dim_kv)
+
+    # Compute lora_output with shape (s, output_dim) as follows:
+    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
+    # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
+    #     = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+    # lora_output[:, output_dim_q + output_dim_kv: ]
+    #     = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+
+    # Get dims
+    s = x.shape[0]
+    input_dim = x.shape[1]
+    r = qkv_lora_b.shape[-1]
+    output_dim = qkv_lora_b.shape[-2]
+    assert input_dim == 3 * r
+    assert output_offset.shape[0] == 4
+
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_OUT = 64
+
+    grid_b = (
+        triton.cdiv(batch_info.max_len, BLOCK_S)
+        * triton.cdiv(max_qkv_out_dim, BLOCK_OUT),
+        3,  # this dimension decides current block computes on q, k or v
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _qkv_lora_b_kernel[grid_b](
+        x,
+        qkv_lora_b,
+        output,
+        r,
+        max_qkv_out_dim,
+        x.stride(0),
+        x.stride(1),
+        qkv_lora_b.stride(0),
+        qkv_lora_b.stride(1),
+        qkv_lora_b.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        output_offset,
+        BLOCK_S,
+        BLOCK_OUT,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+
+    return output
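For orientation, the following is a minimal CPU sketch of what qkv_lora_b_fwd computes, written against the shapes documented in the kernel comments above. The segment bookkeeping (seg_indptr, weight_indices, output_offset passed as plain Python ints) mirrors how the kernel reads its arguments; it is an illustrative reference, not part of the packaged API.

import torch

def qkv_lora_b_reference(x, qkv_lora_b, seg_indptr, weight_indices, output_offset, scaling=1.0):
    # x: (s, 3 * r); qkv_lora_b: (num_lora, N_Q + 2 * N_KV, r)
    r = x.shape[1] // 3
    out = torch.zeros((x.shape[0], qkv_lora_b.shape[1]), dtype=x.dtype)
    for b in range(len(weight_indices)):           # one adapter per sequence segment
        lo, hi = seg_indptr[b], seg_indptr[b + 1]  # rows belonging to this sequence
        w = qkv_lora_b[weight_indices[b]]          # (N_Q + 2 * N_KV, r)
        for qkv in range(3):                       # 0: q, 1: k, 2: v
            n_lo, n_hi = output_offset[qkv], output_offset[qkv + 1]
            out[lo:hi, n_lo:n_hi] = scaling * (x[lo:hi, qkv * r:(qkv + 1) * r] @ w[n_lo:n_hi].T)
    return out

The Triton kernel additionally accumulates the scaled result into base_output when fuse_scaling_add is set.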
sglang/srt/lora/triton_ops/sgemm_lora_a.py
@@ -0,0 +1,143 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _sgemm_lora_a_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Matrix dimensions
+    N,  # r
+    K,  # input_dim
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+
+    # x: (s, K), s is the sum of sequence lengths
+    # weights: (num_lora, N, K)
+    # output: (s, N)
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len
+    batch_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+    x_ptrs = (x + seg_start * x_stride_0) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iteate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def sgemm_lora_a_fwd(
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+) -> torch.Tensor:
+    # x: (s, input_dim)
+    # weights: (num_lora, r, input_dim)
+    # output: (s, r)
+    # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
+    # input_dim is much larger than r
+
+    assert x.is_contiguous()
+    assert weights.is_contiguous()
+    assert len(x.shape) == 2
+    assert len(weights.shape) == 3
+
+    S = x.shape[0]
+    R = weights.shape[-2]
+    K = weights.shape[-1]
+    assert x.shape[-1] == K
+
+    # Block shapes
+    BLOCK_S = 16
+    BLOCK_K = 256
+    BLOCK_R = 16
+
+    grid = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),
+        batch_info.bs,
+    )
+
+    output = torch.empty((S, R), device=x.device, dtype=x.dtype)
+    _sgemm_lora_a_kernel[grid](
+        x,
+        weights,
+        output,
+        R,
+        K,
+        x.stride(0),
+        x.stride(1),
+        weights.stride(0),
+        weights.stride(1),
+        weights.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_R,
+        BLOCK_K,
+    )
+    return output
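A matching CPU sketch for sgemm_lora_a_fwd: each sequence segment is multiplied by the A matrix of the adapter selected for it. The segment arguments follow how the kernel consumes LoraBatchInfo and are illustrative only.

import torch

def sgemm_lora_a_reference(x, weights, seg_indptr, weight_indices):
    # x: (s, input_dim); weights: (num_lora, r, input_dim) -> output: (s, r)
    output = torch.empty((x.shape[0], weights.shape[1]), dtype=x.dtype)
    for b in range(len(weight_indices)):
        lo, hi = seg_indptr[b], seg_indptr[b + 1]  # rows of this sequence
        output[lo:hi] = x[lo:hi] @ weights[weight_indices[b]].T
    return output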
sglang/srt/lora/triton_ops/sgemm_lora_b.py
@@ -0,0 +1,159 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _sgemm_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Matrix dimensions
+    N,  # output_dim
+    K,  # r
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # x: (s, K), s is the sum of sequence lengths
+    # weights: (num_lora, N, K)
+    # output: (s, N)
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len
+    batch_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+    x_ptrs = (x + seg_start * x_stride_0) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iteate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = s_offset[:, None] < seg_len
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def sgemm_lora_b_fwd(
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoraBatchInfo,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+    # x: (s, r)
+    # weights: (num_lora, output_dim, r)
+    # output: (s, output_dim)
+    # output_dim is much larger than r
+
+    assert x.is_contiguous()
+    assert weights.is_contiguous()
+    assert len(x.shape) == 2
+    assert len(weights.shape) == 3
+
+    S = x.shape[0]
+    N = weights.shape[-2]
+    R = weights.shape[-1]
+    assert x.shape[-1] == R
+
+    # Block shapes
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_N = 256
+
+    grid = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N),
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((S, N), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _sgemm_lora_b_kernel[grid](
+        x,
+        weights,
+        output,
+        N,
+        R,
+        x.stride(0),
+        x.stride(1),
+        weights.stride(0),
+        weights.stride(1),
+        weights.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_N,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+    return output
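And the corresponding sketch for sgemm_lora_b_fwd, including the fused path that scales the LoRA delta and accumulates it into a provided base_output. This is a plain-PyTorch reference under the same assumptions as above, not the shipped implementation.

import torch

def sgemm_lora_b_reference(x, weights, seg_indptr, weight_indices, base_output=None, scaling=1.0):
    # x: (s, r); weights: (num_lora, output_dim, r) -> output: (s, output_dim)
    out = base_output if base_output is not None else torch.zeros(
        (x.shape[0], weights.shape[1]), dtype=x.dtype
    )
    for b in range(len(weight_indices)):
        lo, hi = seg_indptr[b], seg_indptr[b + 1]
        # Scaled LoRA delta, added in place (mirrors fuse_scaling_add)
        out[lo:hi] += scaling * (x[lo:hi] @ weights[weight_indices[b]].T)
    return out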
sglang/srt/managers/image_processor.py
CHANGED
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
 class MiniCPMVImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "(<image>./</image>)"
 
     @staticmethod
     def _process_images_task(images, input_text):
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     async def process_images_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_ids,
         request_obj,
         max_req_input_len,
     ):
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_data = [image_data]
 
         image_hashes, image_sizes = [], []
-
-        IMAGE_TOKEN = "(<image>./</image>)"
+        all_frames = []
 
-        # roughly calculate the max number of frames
-        # TODO: the process should be applied to all the visual inputs
+        # roughly calculate the max number of frames under the max_req_input_len limit
         def calculate_max_num_frames() -> int:
             # Model-specific
             NUM_TOKEN_PER_FRAME = 330
 
-            ret = (max_req_input_len - len(
+            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
             return min(ret, 100)
 
-        # if cuda OOM set a smaller number
         MAX_NUM_FRAMES = calculate_max_num_frames()
-        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
 
-
+        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def get_estimated_frames_list():
+            """
+            estimate the total frame count from all visual input
+            """
+            # Before processing inputs
+            estimated_frames_list = []
+            for image in image_data:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    # Estimate frames for the video
+                    vr = VideoReader(path, ctx=cpu(0))
+                    num_frames = len(vr)
+                else:
+                    # For images, each contributes one frame
+                    num_frames = 1
+                estimated_frames_list.append(num_frames)
+
+            return estimated_frames_list
+
+        estimated_frames_list = get_estimated_frames_list()
+        total_frame_count = sum(estimated_frames_list)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        def encode_video(video_path, frame_count_limit=None):
             if not os.path.exists(video_path):
                 logger.error(f"Video {video_path} does not exist")
                 return []
 
-            if
+            if frame_count_limit == 0:
                 return []
 
             def uniform_sample(l, n):
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             vr = VideoReader(video_path, ctx=cpu(0))
             sample_fps = round(vr.get_avg_fps() / 1)  # FPS
             frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if len(frame_idx) >
-                frame_idx = uniform_sample(frame_idx,
+            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+                frame_idx = uniform_sample(frame_idx, frame_count_limit)
             frames = vr.get_batch(frame_idx).asnumpy()
             frames = [Image.fromarray(v.astype("uint8")) for v in frames]
             return frames
 
-        if isinstance(
-            assert len(
-            input_text = self._processor.tokenizer.decode(
-
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
         # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(IMAGE_TOKEN)
+        text_parts = input_text.split(self.IMAGE_TOKEN)
         new_text_parts = []
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = encode_video(path, frame_count_limit=frames_to_process)
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+                image_sizes += frames[0].size * len(frames)
+                image_hashes += [hash(image)] * len(frames)
+                all_frames += frames
+
+            assert frames_to_process == len(frames)
+
             new_text_parts.append(text_parts[image_index])
-
+
+            if frames_to_process != 0:
+                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
 
         new_text_parts.append(text_parts[-1])
+
         input_text = "".join(new_text_parts)
-
+
+        if len(all_frames) == 0:
             return None
-        res = await self._process_images(images=
+        res = await self._process_images(images=all_frames, input_text=input_text)
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
         input_ids = res["input_ids"]
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         if tokenizer.slice_start_id:
             slice_start_id = [tokenizer.slice_start_id]
             slice_end_id = [tokenizer.slice_end_id]
-
         return {
            "input_ids": input_ids.flatten().tolist(),
            "pixel_values": pixel_values,
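The new frame-budgeting in MiniCPMVImageProcessor can be summarized as a standalone sketch: the token budget left in the request bounds the total frame count, and each visual input's estimated frames are scaled down proportionally. The numbers below are invented for illustration.

NUM_TOKEN_PER_FRAME = 330  # model-specific constant from the diff above

max_req_input_len, prompt_len = 16384, 2000
max_num_frames = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)

estimated_frames_list = [120, 1, 1]  # e.g. one long video and two images
scaling_factor = min(1.0, max_num_frames / sum(estimated_frames_list))

# Each input gets at least one frame; the loop in the diff additionally stops
# allocating once MAX_NUM_FRAMES frames have been collected in total.
allocated = [max(1, int(n * scaling_factor)) for n in estimated_frames_list]
print(max_num_frames, scaling_factor, allocated)  # 43, ~0.352, [42, 1, 1]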
sglang/srt/managers/scheduler.py
CHANGED
@@ -149,6 +149,7 @@ class Scheduler:
             if not self.spec_algorithm.is_none()
             else 1
         )
+        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
 
         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -831,10 +832,16 @@ class Scheduler:
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-
+        protected_size = self.tree_cache.protected_size()
+        memory_leak = available_size != (
+            self.max_total_num_tokens
+            if not self.enable_hierarchical_cache
+            else self.max_total_num_tokens - protected_size
+        )
+        if memory_leak:
             msg = (
                 "KV cache pool leak detected!"
-                f"{available_size=}, {self.max_total_num_tokens=}\n"
+                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -949,7 +956,14 @@ class Scheduler:
             res = adder.add_one_req(req)
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
-                    self.
+                    if self.enable_hierarchical_cache:
+                        # Set batch_is_full after making sure there are requests that can be served
+                        self.batch_is_full = len(adder.can_run_list) > 0 or (
+                            self.running_batch is not None
+                            and not self.running_batch.is_empty()
+                        )
+                    else:
+                        self.batch_is_full = True
                 break
             if self.server_args.prefill_only_one_req:
                 break
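The revised leak check boils down to comparing the free KV-token count against an expected total that excludes tokens protected by the hierarchical cache. A small illustrative sketch (not the scheduler's actual method):

def detect_kv_leak(available_size, max_total_num_tokens, protected_size, enable_hierarchical_cache):
    expected = (
        max_total_num_tokens
        if not enable_hierarchical_cache
        else max_total_num_tokens - protected_size
    )
    return available_size != expected

assert detect_kv_leak(10_000, 10_000, 0, False) is False   # flat cache, fully released
assert detect_kv_leak(9_900, 10_000, 100, True) is False   # 100 tokens protected, not a leak
assert detect_kv_leak(9_500, 10_000, 100, True) is True    # 400 tokens unaccounted for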