sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,347 @@
|
|
1
|
+
import cutex
|
2
|
+
import torch
|
3
|
+
|
4
|
+
# Tensor layouts for the `build_tree` kernel below (bs = batch size):
#   parent_list       [bs, topk*depth+1] int64 (width used by the self-test);
#                     row `selected_index // topk` holds the parent's selected value
#   selected_index    [bs, draft_token_num-1] int64
#   verified_seq_len  [bs] int32
#   tree_mask         flat bool of length
#                     sum(verified_seq_len)*draft_token + bs*draft_token*draft_token;
#                     batch entry b owns a row-major block of draft_token rows,
#                     each (verified_seq_len[b] + draft_token) wide
#   positions         [bs*draft_token] int64
#   retrive_index     [bs, draft_token, depth+2] int64
kernels = cutex.SourceModule(
    """
//cuda
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
        Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
    // One block per batch entry; thread `tid` owns draft token `tid`, i.e. row `tid` of the mask.
    int bid = blockIdx.x;
    int tid = threadIdx.x;
    // Surplus threads must not return early: every thread has to reach the
    // __syncthreads() below, so work is guarded by `active` instead.
    bool active = tid < draft_token_num;

    // Offset of this batch entry's mask block: earlier entries' draft blocks plus their prefixes.
    int seq_tree_idx = draft_token_num * draft_token_num * bid;
    for (int i = 0; i < bid; i++) {
        seq_tree_idx += verified_seq_len[i] * draft_token_num;
    }
    int seq_len = verified_seq_len[bid];

    int position = 0;        // length of this token's ancestor chain
    int depends_order[10];   // draft indices on the chain, self first; bounds the depth at 10

    if (active) {
        // First draft-token column (after prefix + root) in row `tid`.
        int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
        for (int i = 0; i < draft_token_num - 1; i++) {
            tree_mask[token_tree_idx + i] = false;
        }

        if (tid == 0) {
            // Root token: sits right after the verified prefix.
            positions[bid * draft_token_num] = seq_len;
            retrive_index[bid][0][0] = bid * draft_token_num;
        } else {
            int cur_position = tid - 1;
            while (position < 10) {
                depends_order[position] = cur_position + 1;
                position += 1;
                tree_mask[token_tree_idx + cur_position] = true;
                int parent_tb_idx = selected_index[bid][cur_position] / topk;
                if (parent_tb_idx == 0) {
                    break;  // parent is the root
                }

                long token_idx = parent_list[bid][parent_tb_idx];
                // Find the parent among the draft_token_num - 1 selected tokens of this row.
                int found = -1;
                for (int j = 0; j < draft_token_num - 1; j++) {
                    if (selected_index[bid][j] == token_idx) {
                        found = j;
                        break;
                    }
                }
                if (found < 0) {
                    break;  // parent not selected: stop rather than index past the row
                }
                cur_position = found;
            }
            positions[bid * draft_token_num + tid] = position + seq_len;
        }
    }

    // Every row of the mask must be finished before any thread scans other rows.
    __syncthreads();
    if (!active || tid == 0) {
        return;
    }

    // Token `tid` is a leaf when its own row is the only one that includes it.
    int referencing_rows = 0;
    for (int i = 1; i < draft_token_num; i++) {
        if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
            referencing_rows++;
        }
    }
    if (referencing_rows == 1) {
        // Row layout: root, then the chain from the root's child down to `tid`.
        for (int i = 0; i < position; i++) {
            retrive_index[bid][tid][position - i] = depends_order[i] + bid * draft_token_num;
        }
        retrive_index[bid][tid][0] = bid * draft_token_num;
    }
}
//!cuda
""",
    float_bits=16,  # `float` in the kernel source maps to a 16-bit type; this kernel uses no `float`.
    boundscheck=True,  # on for debugging, off for performance (to use full threads of a block); default is on.
)
|
80
|
+
|
81
|
+
|
82
|
+
def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
    """Build the draft-tree attention mask, positions and retrieval index.

    Launches the ``build_tree`` CUDA kernel with one block per batch entry.

    Args:
        parent_list: 2-D int64 tensor; row ``b`` is indexed by
            ``top_score_index[b] // topk`` to find a token's parent.
        top_score_index: 2-D int64 tensor with ``draft_token - 1`` selected
            entries per batch row.
        seq_lens: 1-D tensor of verified sequence lengths (cast to int32 for
            the kernel).
        topk: branching factor used to map a selected entry to its parent row.
        depth: maximum tree depth; must be <= 10.
        draft_token: number of draft tokens per batch entry, including the
            root; must be <= 64.

    Returns:
        tree_mask: flat bool tensor of length
            ``sum(seq_lens) * draft_token + bs * draft_token**2``.
        positions: int64 tensor of shape ``[bs * draft_token]``.
        retrive_index: int64 tensor ``[num_rows, depth + 2]`` containing the
            kernel output rows that are not entirely -1.
        retrive_cum_len: int32 tensor ``[bs + 1]`` of prefix sums of the kept
            row counts per batch entry.

    Raises:
        ValueError: if ``draft_token`` exceeds the threads launched per block
            or ``depth`` exceeds the kernel's fixed chain buffer.
    """
    # The kernel runs one thread per draft token in a fixed-size block and
    # stores each ancestor chain in a 10-entry per-thread array. Larger inputs
    # would silently drop tokens or overflow that array.
    threads_per_block = 64
    max_supported_depth = 10
    if draft_token > threads_per_block:
        raise ValueError(
            f"draft_token={draft_token} exceeds the {threads_per_block} threads "
            "launched per block"
        )
    if depth > max_supported_depth:
        raise ValueError(
            f"depth={depth} exceeds the kernel limit of {max_supported_depth}"
        )

    bs = seq_lens.numel()
    device = parent_list.device
    tree_mask = torch.full(
        (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
        True,
        dtype=torch.bool,
        device=device,
    )
    retrive_index = torch.full(
        (bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
    )
    positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)

    kernels.build_tree(
        parent_list,
        top_score_index,
        seq_lens.to(torch.int32),
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token,
        grid=(bs, 1, 1),
        block=(threads_per_block, 1, 1),
    )
    # Rows the kernel left untouched are all -1, i.e. they sum to -(depth + 2).
    index = retrive_index.sum(dim=-1) != -depth - 2
    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
    retrive_cum_len = torch.zeros(
        (cum_len.numel() + 1,), dtype=torch.int32, device=device
    )
    retrive_cum_len[1:] = cum_len
    retrive_index = retrive_index[index]
    return tree_mask, positions, retrive_index, retrive_cum_len
|
116
|
+
|
117
|
+
|
118
|
+
if __name__ == "__main__":
    # Self-test (needs CUDA and cutex): compare build_tree_kernel against a
    # pure-Python reference on a hand-built single-sequence draft tree.

    def findp(p_i, index, parent_list, topk=10):
        """Return the ancestor chain of selected slot ``p_i`` for one batch entry.

        ``index`` and ``parent_list`` are 1-D tensors. The chain starts at
        ``p_i`` itself and ends at the slot whose parent is the root
        (``index[slot] // topk == 0``). Elements are slots into ``index``; the
        matching draft-token id is ``slot + 1``.
        """
        index_list = index.tolist()
        parents = parent_list.tolist()
        chain = [p_i]
        while True:
            parent_row = index_list[p_i] // topk
            if parent_row == 0:
                break
            p_i = index_list.index(parents[parent_row])
            chain.append(p_i)
        return chain

    def create_mask(seq_len, draft_token, index, parent_list, max_depth, topk=10):
        """Reference implementation of build_tree_kernel's first three outputs.

        ``index`` and ``parent_list`` are 2-D (one row per batch entry), like
        the kernel inputs. Returns ``(tree_mask, positions, retrive_index)``
        on CUDA, with the same layouts build_tree_kernel returns.
        """
        mask = []
        positions = []
        retrive_index = []
        for b, lens in enumerate(seq_len.tolist()):
            offset = b * draft_token
            # Root row: attends to the prefix and itself, not to other drafts.
            first_mask = torch.full((lens + draft_token,), True)
            first_mask[lens + 1 :] = False
            positions.append(lens)
            mask.append(first_mask)

            seq_order = []
            r_index = [[offset] + [-1] * (max_depth + 1)]
            for j in range(draft_token - 1):
                mask.append(torch.full((lens + 1,), True))
                idx = findp(j, index[b], parent_list[b], topk)
                seq_order.append(idx)
                positions.append(len(idx) + lens)
                t = torch.full((draft_token - 1,), False)
                t[idx] = True
                mask.append(t)

            # A slot is a leaf when only its own chain contains it.
            for i in range(draft_token - 1):
                referencing = sum(1 for chain in seq_order if i in chain)
                if referencing == 1:
                    order_list = [offset] + [x + 1 + offset for x in seq_order[i][::-1]]
                    order_list += [-1] * (max_depth + 1 - len(seq_order[i]))
                    r_index.append(order_list)
            retrive_index.append(torch.tensor(r_index, dtype=torch.long))

        return (
            torch.cat(mask).cuda(),
            torch.tensor(positions, dtype=torch.long).cuda(),
            torch.cat(retrive_index).cuda(),
        )

    topk = 10
    depth = 5  # the kernel's per-thread chain buffer requires depth <= 10
    draft_token = 64

    # Selected entries for the draft_token - 1 non-root tokens.
    index = torch.tensor(
        [
            0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 30,
            110, 130, 150, 160,
            210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 230,
            310, 311, 312, 313, 314, 315, 316, 317, 320, 321, 322,
            330, 360, 380, 390,
            410, 411, 412, 413, 414, 415, 416, 417, 418, 419,
            420, 421, 422, 423, 430, 431, 440, 441, 460, 470,
        ],
        dtype=torch.long,
        device="cuda",
    )

    parent_list = torch.tensor(
        [
            -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
            20, 30, 21, 13, 22, 40, 23,
            110, 130, 160, 150, 190, 120, 111, 121, 200, 180,
            210, 211, 212, 213, 214, 215, 216, 220, 230, 217,
            310, 311, 312, 313, 320, 314, 321, 315, 316, 317,
        ],
        dtype=torch.long,
        device="cuda",
    )

    verified_seq_len = torch.tensor([47], dtype=torch.long, device="cuda")

    # build_tree_kernel casts seq_lens to the int32 the kernel expects.
    tree_mask, positions, retrive_index, retrive_cum_len = build_tree_kernel(
        parent_list.unsqueeze(0),
        index.unsqueeze(0),
        verified_seq_len,
        topk,
        depth,
        draft_token,
    )

    c_mask, c_positions, c_retrive_index = create_mask(
        verified_seq_len,
        draft_token,
        index.unsqueeze(0),
        parent_list.unsqueeze(0),
        depth,
        topk,
    )

    # Exact comparisons: the outputs are bool/int tensors of identical shapes.
    assert torch.equal(tree_mask, c_mask), "tree mask has error."
    assert torch.equal(positions, c_positions), "positions has error."
    assert torch.equal(retrive_index, c_retrive_index), "retrive_index has error."
    assert (
        retrive_cum_len[-1].item() == c_retrive_index.shape[0]
    ), "retrive_cum_len has error."
|