sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,347 @@
1
+ import cutex
2
+ import torch
3
+
4
+ # parent_list [bs, topk*depth+1]
5
+ # selected_index [bs, draft_token_num-1]
6
+ # verified_seq_len [bs]
7
+ # tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
8
+ # positions [bs*draft_token]
9
+ # retrive_index [bs, draft_token, depth+2]
10
# CUDA source for `build_tree`, compiled through cutex.
#
# Launch layout: one block per request (blockIdx.x) and one thread per tree
# node (threadIdx.x); thread 0 is the already-verified root token and thread
# t >= 1 handles draft token t-1 of `selected_index`.
#
# Changes relative to the previous revision of this kernel:
#   * every thread now reaches a block barrier before leaf detection, because
#     that step reads mask rows written by *other* threads;
#   * the parent search only scans the draft_token_num-1 selected tokens and
#     stops when the parent is missing instead of walking out of bounds;
#   * the ancestor walk is capped at the size of the local chain buffer.
kernels = cutex.SourceModule(
    """
//cuda
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
        Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
    // Capacity of the per-thread ancestor buffer (was a bare literal 10).
    const int kMaxChainLen = 10;

    int bid = blockIdx.x;
    int tid = threadIdx.x;
    // Surplus threads do no work but must still arrive at the barrier below.
    bool active = tid < draft_token_num;

    // Offset of this request's slice inside the flat mask: every earlier
    // request occupies draft_token_num rows of (its seq_len + draft_token_num).
    int seq_tree_idx = draft_token_num * draft_token_num * bid;
    for (int i = 0; i < bid; i++) {
        seq_tree_idx += verified_seq_len[i] * draft_token_num;
    }
    int seq_len = verified_seq_len[bid];

    int position = 0;                  // number of nodes on the chain so far
    int depends_order[kMaxChainLen];   // chain of node ids, this token first

    if (active) {
        // First draft-token column of this thread's mask row (the prefix and
        // the root column before it keep their initial value).
        int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
        for (int i = 0; i < draft_token_num - 1; i++) {
            tree_mask[token_tree_idx + i] = false;
        }

        if (tid == 0) {
            // Root token: sits right after the verified prefix and starts
            // the first retrieval row.
            positions[bid * draft_token_num] = seq_len;
            retrive_index[bid][0][0] = bid * draft_token_num;
        } else {
            int cur_position = tid - 1;
            while (position < kMaxChainLen) {
                depends_order[position] = cur_position + 1;
                position += 1;
                tree_mask[token_tree_idx + cur_position] = true;
                int parent_tb_idx = selected_index[bid][cur_position] / topk;
                if (parent_tb_idx == 0) {
                    // Parent is the root: the chain is complete.
                    break;
                }

                // Locate the parent among the selected draft tokens.
                int token_idx = parent_list[bid][parent_tb_idx];
                for (cur_position = 0; cur_position < draft_token_num - 1; cur_position++) {
                    if (selected_index[bid][cur_position] == token_idx) {
                        break;
                    }
                }
                if (cur_position == draft_token_num - 1) {
                    // Parent was not selected: stop rather than index past
                    // the end of selected_index / this mask row.
                    break;
                }
            }
            positions[bid * draft_token_num + tid] = position + seq_len;
        }
    }

    // Leaf detection below reads rows produced by the other threads of the
    // block, so all rows must be complete and visible first.
    __syncthreads();

    if (!active || tid == 0) {
        return;
    }

    // Count the rows whose chain contains this token (its own row included).
    // Exactly one hit means no other token descends from it, i.e. it is a leaf.
    int is_leaf = 0;
    for (int i = 1; i < draft_token_num; i++) {
        if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
            is_leaf++;
        }
    }
    if (is_leaf == 1) {
        // depends_order runs from this token up to its top-level ancestor;
        // store it reversed so the row reads root -> ... -> leaf.
        for (int i = 0; i < position; i++) {
            retrive_index[bid][tid][position - i] = depends_order[i] + bid * draft_token_num;
        }
        retrive_index[bid][tid][0] = bid * draft_token_num;
    }
}
//!cuda
""",
    float_bits=16,  # change to 16 to use half precision as `float` type in the above source code.
    boundscheck=True,  # turning on for debug and off for performance (to use full threads of a block), default is on.
)
80
+
81
+
82
def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
    """Launch the `build_tree` CUDA kernel and post-process its outputs.

    Args:
        parent_list: 2-D long tensor handed to the kernel as ``parent_list``;
            all outputs are allocated on its device.
        top_score_index: 2-D long tensor handed to the kernel as
            ``selected_index``.
        seq_lens: per-request verified sequence lengths; cast to int32 for
            the kernel. Its element count is used as the batch size.
        topk: branching factor forwarded to the kernel.
        depth: tree depth; retrieval rows have ``depth + 2`` slots.
        draft_token: number of tree nodes per request (root included). Must
            not exceed the 64 threads launched per block.

    Returns:
        A tuple ``(tree_mask, positions, retrive_index, retrive_cum_len)``:
        the flat boolean mask, the flat position tensor, the retrieval rows
        that the kernel actually filled in, and an int32 tensor holding a
        leading 0 followed by the running count of such rows per request.

    Raises:
        ValueError: if ``draft_token`` is larger than the launch block size,
            since the surplus tokens would never be processed.
    """
    threads_per_block = 64
    if draft_token > threads_per_block:
        raise ValueError(
            f"draft_token ({draft_token}) exceeds the {threads_per_block} "
            "threads launched per block; the extra tokens would be left unprocessed."
        )

    bs = seq_lens.numel()
    device = parent_list.device

    # One row of (seq_len + draft_token) entries per tree node, all requests
    # laid out back to back. `.item()` synchronises with the device.
    mask_numel = torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs
    tree_mask = torch.full((mask_numel,), True, device=device)
    retrive_index = torch.full(
        (bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
    )
    positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)

    kernels.build_tree(
        parent_list,
        top_score_index,
        seq_lens.to(torch.int32),
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token,
        grid=(bs, 1, 1),
        block=(threads_per_block, 1, 1),
    )

    # A row the kernel never touched is still all -1 and therefore sums to
    # -(depth + 2); every other row is a real retrieval path.
    index = retrive_index.sum(dim=-1) != -depth - 2
    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
    # Allocate next to the other outputs instead of on the default "cuda" device.
    retrive_cum_len = torch.zeros(
        (cum_len.numel() + 1,), dtype=torch.int32, device=device
    )
    retrive_cum_len[1:] = cum_len
    retrive_index = retrive_index[index]
    return tree_mask, positions, retrive_index, retrive_cum_len
116
+
117
+
118
+ if __name__ == "__main__":
119
+
120
def findp(p_i, index, parent_list, topk=10):
    """Return the chain of draft-token positions from ``p_i`` up to the root.

    Args:
        p_i: position (into ``index``) of the token to start from.
        index: 1-D integer tensor of selected token ids; ``id // topk`` is
            the row of ``parent_list`` naming that token's parent.
        parent_list: 1-D integer tensor mapping a parent row to the id of
            the parent token.
        topk: branching factor used to derive the parent row. Defaults to
            10, the value that used to be hard-coded here.

    Returns:
        A list of positions starting with ``p_i`` and ending with the
        ancestor whose parent row is 0 (i.e. a child of the root).

    Raises:
        ValueError: if a parent id does not occur in ``index``.
    """
    # Convert once so the walk below works on plain Python ints.
    parent_rows = (index // topk).tolist()
    index_list = index.tolist()
    parent_list = parent_list.tolist()

    res = [p_i]
    while parent_rows[p_i] != 0:
        token_idx = parent_list[parent_rows[p_i]]
        p_i = index_list.index(token_idx)
        res.append(p_i)
    return res
133
+
134
def create_mask(seq_len, draft_token, index, parent_list, max_depth):
    """Build the reference mask, positions and retrieval rows in PyTorch.

    This is the slow reference used to check the `build_tree` kernel.

    Args:
        seq_len: 1-D tensor of verified sequence lengths, one per request.
        draft_token: number of tree nodes per request (root included).
        index: 1-D tensor of selected token ids, passed on to ``findp``.
        parent_list: 1-D tensor of parent ids, passed on to ``findp``.
        max_depth: tree depth; retrieval rows have ``max_depth + 2`` slots.

    Returns:
        ``(mask, positions, retrive_index)``, all on the CUDA device: the
        flat boolean mask, the flat long position tensor, and the stacked
        retrieval rows per request.

    Note:
        Retrieval rows use node ids without a per-request offset (the kernel
        adds ``bid * draft_token_num``), and the same ``index`` /
        ``parent_list`` are used for every request, so the result lines up
        with the kernel only for a single request.
    """
    mask = []
    positions = []
    retrive_index = []
    for lens in seq_len.tolist():
        # Root row: the whole prefix plus the root itself is visible, none of
        # the draft tokens are.
        first_mask = torch.full((lens + draft_token,), True)
        first_mask[-(draft_token - 1) :] = False
        positions.append(lens)
        mask.append(first_mask)

        seq_order = []
        # Was built from the global `depth`; use the parameter instead.
        first_index = torch.tensor(
            [0] + [-1] * (max_depth + 1), dtype=torch.long, device="cuda"
        )
        r_index = [first_index]
        for j in range(draft_token - 1):
            # Prefix and root are visible to every draft token.
            mask.append(torch.full((lens + 1,), True))
            idx = findp(j, index, parent_list)

            seq_order.append(idx)
            # Was `len(idx) + seq_len`, which added the whole tensor instead
            # of this request's length.
            positions.append(len(idx) + lens)
            t = torch.full((draft_token - 1,), False)
            t[idx] = True
            mask.append(t)

        # NOTE(review): token 0 is never tested for leaf-ness here, whereas
        # the kernel tests every draft token - confirm this is intended.
        for token in range(1, draft_token - 1):
            # A token that appears only on its own chain has no descendants.
            is_leaf = sum(1 for chain in seq_order if token in chain)

            if is_leaf == 1:
                chain = seq_order[token]
                order_list = [0] + [x + 1 for x in chain[::-1]]
                order_list.extend([-1] * (max_depth + 1 - len(chain)))
                r_index.append(
                    torch.tensor(order_list, dtype=torch.long, device="cuda")
                )
        retrive_index.append(torch.stack(r_index))

    return (
        torch.cat(mask).cuda(),
        torch.tensor(positions, dtype=torch.long, device="cuda"),
        torch.stack(retrive_index),
    )
175
+
176
# Self-test: run the kernel on one hand-written tree and compare it with the
# PyTorch reference (`create_mask`). Needs a CUDA device. The guard is stated
# here so that this section never runs on import.
if __name__ == "__main__":
    # Ids of the 63 selected draft tokens (draft_token - 1).
    index = torch.tensor(
        [
            0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 30,
            110, 130, 150, 160,
            210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 230,
            310, 311, 312, 313, 314, 315, 316, 317, 320, 321, 322,
            330, 360, 380, 390,
            410, 411, 412, 413, 414, 415, 416, 417, 418, 419,
            420, 421, 422, 423, 430, 431, 440, 441, 460, 470,
        ],
        dtype=torch.long,
        device="cuda",
    )

    # Parent table with 51 entries (topk * depth + 1).
    parent_list = torch.tensor(
        [
            -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
            20, 30, 21, 13, 22, 40, 23,
            110, 130, 160, 150, 190, 120, 111, 121, 200, 180,
            210, 211, 212, 213, 214, 215, 216, 220, 230, 217,
            310, 311, 312, 313, 320, 314, 321, 315, 316, 317,
        ],
        dtype=torch.long,
        device="cuda",
    )

    verified_seq_len = torch.tensor([47], dtype=torch.long, device="cuda")
    bs = verified_seq_len.shape[0]
    topk = 10
    depth = 5  # depth <= 10
    draft_token = 64

    tree_mask = torch.full(
        (
            torch.sum(verified_seq_len).item() * draft_token
            + draft_token * draft_token * bs,
        ),
        True,
        device="cuda",
    )
    retrive_index = torch.full(
        (bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
    )
    positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)

    kernels.build_tree(
        parent_list.unsqueeze(0),
        index.unsqueeze(0),
        # The kernel declares this argument as Tensor<int, 1>; cast the same
        # way build_tree_kernel does instead of passing int64 data.
        verified_seq_len.to(torch.int32),
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token,
        grid=(bs, 1, 1),
        block=(64, 1, 1),
    )
    # Drop the rows the kernel left untouched (still all -1).
    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]

    c_mask, c_positions, c_retrive_index = create_mask(
        verified_seq_len, draft_token, index, parent_list, depth
    )

    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
    assert torch.allclose(positions, c_positions), "positions has error."
    assert torch.allclose(retrive_index, c_retrive_index), "retrive_index has error."