sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/global_config.py +2 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/attention/flashinfer_backend.py +265 -147
- sglang/srt/layers/attention/triton_backend.py +358 -72
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/linear.py +12 -5
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +51 -5
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
- sglang/srt/layers/quantization/fp8_kernel.py +123 -17
- sglang/srt/layers/quantization/fp8_utils.py +33 -4
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +16 -3
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/tokenizer_manager.py +6 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +16 -1
- sglang/srt/model_executor/model_runner.py +12 -2
- sglang/srt/models/deepseek_v2.py +17 -7
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +64 -21
- sglang/srt/speculative/eagle_worker.py +29 -8
- sglang/srt/utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/gate_up_lora_b.py ADDED
@@ -0,0 +1,170 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.utils import LoRABatchInfo
+
+
+@triton.jit
+def _gate_up_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Parameters of size
+    K,  # K = R
+    output_dim,
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # This kernel packs 2 sgemms (gate/up) into a single kernel.
+
+    # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
+    # weights: (num_lora, 2 * output_dim, K)
+    # output: (s, 2 * output_dim)
+    # output_dim >> K
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len.
+    # gate_up_id decides which of gate or up (0: gate, 1: up)
+    batch_id = tl.program_id(axis=2)
+    gate_up_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+    n_start = gate_up_id * output_dim  # offset on output dim
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(output_dim, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+
+    x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iteate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            and (n_offset[None, :] < output_dim),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def gate_up_lora_b_fwd(
+    x: torch.Tensor,
+    gate_up_lora_b: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    output_dim: int,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+
+    # x: (s, 2 * r)
+    # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+    # output: (s, 2 * output_dim)
+
+    # Compute lora_output with shape (s, output_dim) as follows:
+    # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
+    # lora_output[:, output_dim:]
+    #      = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
+
+    # Get dims
+    s = x.shape[0]
+    input_dim = x.shape[1]
+    r = gate_up_lora_b.shape[-1]
+    assert input_dim == 2 * r
+
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_OUT = 64
+
+    grid_b = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
+        2,  # this dimension decides current block computes on gate or up proj
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _gate_up_lora_b_kernel[grid_b](
+        x,
+        gate_up_lora_b,
+        output,
+        r,
+        output_dim,
+        x.stride(0),
+        x.stride(1),
+        gate_up_lora_b.stride(0),
+        gate_up_lora_b.stride(1),
+        gate_up_lora_b.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_OUT,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+
+    return output
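The hunk above adds a Triton kernel that fuses the gate and up LoRA-B GEMMs of a batch into a single launch, with the per-sequence segmentation and adapter selection carried by LoRABatchInfo. Below is a minimal usage sketch, not part of the diff: it assumes a CUDA device with Triton available, imports the module by its file path, and uses toy sizes chosen purely for illustration.

# Hypothetical usage sketch (illustrative only). Shapes follow the comments in
# gate_up_lora_b_fwd: x is (s, 2 * r), weights are (num_lora, 2 * output_dim, r).
import torch

from sglang.srt.lora.triton_ops.gate_up_lora_b import gate_up_lora_b_fwd
from sglang.srt.lora.utils import LoRABatchInfo

s, r, output_dim, num_lora = 8, 16, 1024, 2  # toy sizes, assumed for illustration
x = torch.randn(s, 2 * r, device="cuda", dtype=torch.float16)
weights = torch.randn(num_lora, 2 * output_dim, r, device="cuda", dtype=torch.float16)

batch_info = LoRABatchInfo(
    bs=2,
    seg_lens=torch.tensor([5, 3], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.tensor([0, 1], dtype=torch.int64, device="cuda"),
)

out = gate_up_lora_b_fwd(x, weights, batch_info, output_dim, scaling=1.0)
print(out.shape)  # torch.Size([8, 2048]), i.e. (s, 2 * output_dim)

Passing a base_output tensor instead of None takes the fuse_scaling_add path, accumulating the scaled LoRA delta into the base projection output rather than allocating a fresh tensor.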
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
 def qkv_lora_b_fwd(
     x: torch.Tensor,
     qkv_lora_b: torch.Tensor,
-    batch_info:
+    batch_info: LoRABatchInfo,
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
     # output: (s, output_dim_q + 2 * output_dim_kv)
 
     # Compute lora_output with shape (s, output_dim) as follows:
-    # lora_output[:, :output_dim_q] = sgemm(
+    # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :outptu_dim_q, :])
     # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
-    #      = sgemm(
+    #      = sgemm(x[:, r: 2 * r], qkv_lora_b[:, outptu_dim_q: output_dim_q + output_dim_kv, :])
     # lora_output[:, output_dim_q + output_dim_kv: ]
-    #      = sgemm(
+    #      = sgemm(x[:, 2 * r: , qkv_lora_b[:, output_dim_q + output_dim_kv: , :])
 
     # Get dims
     s = x.shape[0]
sglang/srt/lora/triton_ops/sgemm_lora_a.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info:
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
 ) -> torch.Tensor:
     # x: (s, input_dim)
     # weights: (num_lora, r, input_dim)
sglang/srt/lora/triton_ops/sgemm_lora_b.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
 def sgemm_lora_b_fwd(
     x: torch.Tensor,
     weights: torch.Tensor,
-    batch_info:
+    batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
     scaling: float = 1.0,
 ) -> torch.Tensor:
sglang/srt/lora/utils.py ADDED
@@ -0,0 +1,141 @@
+import re
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Set, Tuple
+
+import torch
+
+from sglang.srt.hf_transformers_utils import AutoConfig
+
+
+@dataclass
+class LoRABatchInfo:
+    # Batch size
+    bs: int
+
+    # Lengths of each sequence in shape (bs,)
+    seg_lens: torch.Tensor
+
+    # Indice pointers of each sequence in shape (bs + 1, )
+    seg_indptr: torch.Tensor
+
+    # Maximum sequence length of current batch
+    max_len: int
+
+    # The index of lora adapter used by each sequence, in shape (bs,)
+    weight_indices: torch.Tensor
+
+
+class LoRAType(Enum):
+    LORA_A = 0
+    LORA_B = 1
+
+
+def get_layer_id(name: str) -> int:
+    """
+    Extract integer id of layer from its name in string.
+    """
+    match = re.search(r"layers\.(\d+)\.", name)
+    if match is None:
+        return None
+    return int(match.group(1))
+
+
+def get_customized_names_from_hf_names(
+    hf_module_names: Set[str], base_model: torch.nn.Module
+) -> Set[str]:
+    """
+    This function takes in a set of huggingface style module names:
+         e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+    and outputs a set of module names of customized sglang layers:
+         e.g., {"qkv_proj", "o_proj"}
+    """
+    if hasattr(base_model, "get_module_name"):
+        return {base_model.get_module_name(name) for name in hf_module_names}
+    else:
+        """
+        Fallback solution of mapping from config module name to module name in model class.
+        Please check if it aligns with your base model.
+        Please implement the function in the model class if it is not.
+        You can reference this function in llama.py.
+        """
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return {params_mapping.get(name, name) for name in hf_module_names}
+
+
+def get_hidden_dim(
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+) -> Tuple[int]:
+    """
+    Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+    """
+
+    if hasattr(base_model, "get_hidden_dim"):
+        return base_model.get_hidden_dim(module_name)
+    else:
+        """
+        WARNING: get_hidden_dim() is not defined,
+        which is used to get the hidden dim for different lora modules
+        Use the default one, but please check if it is correct for your model.
+        Please implement the function in the model class if it is not.
+        You can reference this function in llama.py.
+        """
+        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+            return config.hidden_size, config.hidden_size
+        elif module_name in ["kv_proj"]:
+            return config.hidden_size, config.hidden_size // (
+                config.num_attention_heads // config.num_key_value_heads
+            )
+        elif module_name == "gate_up_proj":
+            return config.hidden_size, config.intermediate_size
+        elif module_name == "down_proj":
+            return config.intermediate_size, config.hidden_size
+        else:
+            raise NotImplementedError()
+
+
+def get_stacked_name(name: str) -> Tuple[str]:
+    """
+    Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
+    """
+    params_mapping = {
+        "q_proj": ("qkv_proj", "q_proj"),
+        "k_proj": ("qkv_proj", "kv_proj"),
+        "v_proj": ("qkv_proj", "kv_proj"),
+        "gate_proj": ("gate_up_proj", "gate_up_proj"),
+        "up_proj": ("gate_up_proj", "gate_up_proj"),
+    }
+    return params_mapping.get(name, (name, name))
+
+
+def get_stacked_multiply(module_name: str) -> int:
+    """
+    Mapping a lora module name to its magnification at output dimension
+    """
+    stacked_rank = {
+        "qkv_proj": 3,
+        "kv_proj": 2,
+        "gate_up_proj": 2,
+    }
+    return stacked_rank[module_name] if module_name in stacked_rank else 1
+
+
+def get_weight_name(
+    target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
+) -> Optional[str]:
+    """
+    target_name is name of a given module,
+    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
+    If there is a weight name in lora_weight_names that can match target_name, return this name
+    Else return None
+    """
+    idx = 0 if lora_type == LoRAType.LORA_A else 1
+    for weight_name_pair in lora_weight_names:
+        if weight_name_pair[idx] in target_name:
+            return weight_name_pair[idx]
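The new utils module centralizes the batch metadata (LoRABatchInfo) and the module-name mapping helpers that the refactored lora_manager, mem_pool, and lora/layers code in this release build on. A small illustrative sketch of the helpers follows; the module paths passed in are made-up examples, and only functions defined in the hunk above are used.

# Illustrative only: exercises the helpers from sglang/srt/lora/utils.py with toy names.
from sglang.srt.lora.utils import (
    LoRAType,
    get_layer_id,
    get_stacked_multiply,
    get_stacked_name,
    get_weight_name,
)

print(get_layer_id("model.layers.17.self_attn.qkv_proj"))  # 17
print(get_stacked_name("k_proj"))        # ('qkv_proj', 'kv_proj')
print(get_stacked_multiply("qkv_proj"))  # 3: q, k and v are stacked on the output dim

# get_weight_name picks the A- or B-side stacked name that occurs in a module path.
pairs = {get_stacked_name("q_proj"), get_stacked_name("gate_proj")}
print(get_weight_name("model.layers.0.self_attn.qkv_proj", pairs, LoRAType.LORA_A))  # 'qkv_proj'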
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -210,6 +210,7 @@ class DetokenizerManager:
                 input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                 output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                 output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                output_hidden_states=recv_obj.output_hidden_states,
             )
         )
 
sglang/srt/managers/io_struct.py CHANGED
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
 
+    output_hidden_states: List[List[float]]
+
 
 @dataclass
 class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
 
+    output_hidden_states: List[List[float]]
+
 
 @dataclass
 class BatchEmbeddingOut:
sglang/srt/managers/schedule_batch.py CHANGED
@@ -65,6 +65,7 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
 }
 
 logger = logging.getLogger(__name__)
@@ -315,6 +316,7 @@ class Req:
         self.output_token_logprobs_val = self.output_token_logprobs_idx = (
             self.output_top_logprobs_val
         ) = self.output_top_logprobs_idx = None
+        self.hidden_states = []
 
         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
@@ -604,6 +606,9 @@ class ScheduleBatch:
     # Enable custom logit processor
     enable_custom_logit_processor: bool = False
 
+    # Return hidden states
+    return_hidden_states: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -615,6 +620,7 @@ class ScheduleBatch:
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
+        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
@@ -629,6 +635,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
+            return_hidden_states=return_hidden_states,
         )
 
     def batch_size(self):
@@ -1196,9 +1203,15 @@ class ScheduleBatch:
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
-
-                if self.
-                else
+                CaptureHiddenMode.FULL
+                if self.return_hidden_states
+                else (
+                    getattr(
+                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+                    )
+                    if self.spec_info
+                    else CaptureHiddenMode.NULL
+                )
             ),
         )
 
sglang/srt/managers/scheduler.py CHANGED
@@ -997,6 +997,7 @@ class Scheduler:
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            self.server_args.return_hidden_states,
         )
         new_batch.prepare_for_extend()
 
@@ -1156,6 +1157,8 @@ class Scheduler:
                 logits_output.input_token_logprobs.tolist()
             )
 
+        hidden_state_offset = 0
+
        # Check finish conditions
         logprob_pt = 0
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1185,21 @@ class Scheduler:
                    i, req, logprob_pt, next_token_ids, logits_output
                )
 
+            if (
+                self.server_args.return_hidden_states
+                and logits_output.hidden_states is not None
+            ):
+                req.hidden_states.append(
+                    logits_output.hidden_states[
+                        hidden_state_offset : (
+                            hidden_state_offset := hidden_state_offset
+                            + len(req.origin_input_ids)
+                        )
+                    ]
+                    .cpu()
+                    .clone()
+                )
+
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
@@ -1275,6 +1293,12 @@ class Scheduler:
                    logits_output.next_token_top_logprobs_idx[i]
                )
 
+            if (
+                self.server_args.return_hidden_states
+                and logits_output.hidden_states is not None
+            ):
+                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
@@ -1398,6 +1422,7 @@ class Scheduler:
             completion_tokens = []
             cached_tokens = []
             spec_verify_ct = []
+            hidden_states = []
 
             if return_logprob:
                 input_token_logprobs_val = []
@@ -1464,6 +1489,8 @@ class Scheduler:
                    output_top_logprobs_val.append(req.output_top_logprobs_val)
                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
 
+                hidden_states.append(req.hidden_states)
+
             # Send to detokenizer
             if rids:
                 self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1517,7 @@ class Scheduler:
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
+                        hidden_states,
                    )
                )
         else:  # embedding or reward model
@@ -1553,6 +1581,7 @@ class Scheduler:
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            self.server_args.return_hidden_states,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -796,6 +796,12 @@ class TokenizerManager:
                 }
             )
 
+            if (
+                hasattr(recv_obj, "output_hidden_states")
+                and len(recv_obj.output_hidden_states[i]) > 0
+            ):
+                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
+
             if isinstance(recv_obj, BatchStrOut):
                 out_dict = {
                     "text": recv_obj.output_strs[i],
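Taken together, the io_struct, schedule_batch, scheduler, detokenizer-manager, and tokenizer-manager hunks plumb per-request hidden states from the model runner through to the response metadata, gated by the new return_hidden_states server argument (see server_args.py in the file list). A hedged client-side sketch follows; the dashed CLI flag name is inferred from the ServerArgs field and should be verified against the installed version, while the /generate endpoint, default port 30000, and meta_info response field are standard SGLang behavior.

# Hypothetical client sketch. Assumes the server was launched with something like
#   python -m sglang.launch_server --model-path <model> --return-hidden-states
# (flag name inferred from the new ServerArgs field).
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 8}},
)
meta = resp.json()["meta_info"]
# Per the TokenizerManager hunk above, hidden states are attached to meta_info
# only when they were captured for this request.
print("hidden_states" in meta, len(meta.get("hidden_states", [])))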
sglang/srt/managers/tp_worker_overlap_thread.py CHANGED
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
                logits_output.input_token_logprobs = (
                    logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                )
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states.to(
+                    "cpu", non_blocking=True
+                )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()
 
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -33,6 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,6 +132,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         if bs <= model_runner.req_to_token_pool.size
         and bs <= server_args.cuda_graph_max_bs
     ]
+    if is_hip_:
+        capture_bs += [i * 8 for i in range(21, 33)]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
@@ -237,6 +242,7 @@ class CudaGraphRunner:
             "1. disable cuda graph by --disable-cuda-graph\n"
             "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
             "3. disable torch compile by not using --enable-torch-compile\n"
+            "4. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
             "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
         )
 
@@ -348,7 +354,13 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
-
+                CaptureHiddenMode.FULL
+                if self.model_runner.server_args.return_hidden_states
+                else (
+                    spec_info.capture_hidden_mode
+                    if spec_info
+                    else CaptureHiddenMode.NULL
+                )
             ),
         )
 
@@ -462,8 +474,11 @@ class CudaGraphRunner:
             ),
             positions=None,
             retrive_index=None,
+            retrive_next_token=None,
+            retrive_next_sibling=None,
             retrive_cum_len=None,
             draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+            spec_steps=self.model_runner.server_args.speculative_num_steps,
             capture_hidden_mode=CaptureHiddenMode.FULL,
         )
 
|
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
|
|
67
67
|
monkey_patch_p2p_access_check,
|
68
68
|
monkey_patch_vllm_gguf_config,
|
69
69
|
set_cpu_offload_max_bytes,
|
70
|
+
set_cuda_arch,
|
70
71
|
)
|
71
72
|
|
72
73
|
logger = logging.getLogger(__name__)
|
@@ -110,8 +111,14 @@ class ModelRunner:
|
|
110
111
|
):
|
111
112
|
# TODO: add MLA optimization on CPU
|
112
113
|
if self.server_args.device != "cpu":
|
113
|
-
|
114
|
-
|
114
|
+
if server_args.enable_flashinfer_mla:
|
115
|
+
logger.info(
|
116
|
+
"FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
|
117
|
+
)
|
118
|
+
self.server_args.attention_backend = "flashinfer"
|
119
|
+
else:
|
120
|
+
logger.info("MLA optimization is turned on. Use triton backend.")
|
121
|
+
self.server_args.attention_backend = "triton"
|
115
122
|
|
116
123
|
if self.server_args.enable_double_sparsity:
|
117
124
|
logger.info(
|
@@ -169,6 +176,7 @@ class ModelRunner:
|
|
169
176
|
"enable_dp_attention": server_args.enable_dp_attention,
|
170
177
|
"enable_ep_moe": server_args.enable_ep_moe,
|
171
178
|
"device": server_args.device,
|
179
|
+
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
172
180
|
}
|
173
181
|
)
|
174
182
|
|
@@ -292,6 +300,8 @@ class ModelRunner:
|
|
292
300
|
if torch.cuda.get_device_capability()[1] < 5:
|
293
301
|
raise RuntimeError("SGLang only supports sm75 and above.")
|
294
302
|
|
303
|
+
set_cuda_arch()
|
304
|
+
|
295
305
|
# Prepare the model config
|
296
306
|
self.load_config = LoadConfig(
|
297
307
|
load_format=self.server_args.load_format,
|