sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/function_call_parser.py +96 -69
  4. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  5. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  6. sglang/srt/layers/attention/triton_backend.py +64 -16
  7. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  9. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
  10. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  22. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/fp8_kernel.py +43 -10
  24. sglang/srt/lora/backend/__init__.py +25 -5
  25. sglang/srt/lora/backend/base_backend.py +31 -9
  26. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  27. sglang/srt/lora/backend/triton_backend.py +34 -4
  28. sglang/srt/lora/layers.py +293 -0
  29. sglang/srt/lora/lora.py +101 -326
  30. sglang/srt/lora/lora_manager.py +101 -269
  31. sglang/srt/lora/mem_pool.py +174 -0
  32. sglang/srt/lora/triton_ops/__init__.py +7 -1
  33. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  34. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  35. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  36. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  37. sglang/srt/lora/utils.py +141 -0
  38. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  39. sglang/srt/models/llama.py +8 -3
  40. sglang/srt/speculative/build_eagle_tree.py +482 -102
  41. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  42. sglang/srt/speculative/eagle_utils.py +134 -61
  43. sglang/srt/speculative/eagle_worker.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  46. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
  47. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  49. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
@@ -1,124 +1,175 @@
1
- import cutex
1
+ # NOTE: Please run this file to make sure the test cases are correct.
2
+
3
+ from typing import List
4
+
2
5
  import torch
3
6
 
4
- # parent_table [bs,topk*depth+)]
5
- # selected_index [bs,draft_token_num-1)]
6
- # verified_seq_len [bs]
7
- # tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
8
- # positions [bs*draft_token]
9
- # retrive_index [b, draft_token, depth+2]
10
- kernels = cutex.SourceModule(
11
- """
12
- //cuda
13
- __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
14
- Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
15
- int bid = blockIdx.x;
16
- int tid = threadIdx.x;
17
- if (tid >= draft_token_num){
18
- return;
19
- }
20
- int seq_tree_idx = draft_token_num * draft_token_num * bid;
21
- for(int i=0; i<bid; i++){
22
- seq_tree_idx += verified_seq_len[i] * draft_token_num;
23
- }
24
- int seq_len = verified_seq_len[bid];
25
- int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
26
- for(int i=0; i<draft_token_num-1; i++){
27
- tree_mask[token_tree_idx+i] = false;
28
- }
29
-
30
- int position = 0;
31
- if (tid==0){
32
- positions[bid*draft_token_num] = seq_len;
33
- retrive_index[bid][0][0] = bid * draft_token_num;
34
- return;
35
- }
36
-
37
- int depends_order[10];
38
-
39
- int cur_position = tid-1;
40
- while(true){
41
- depends_order[position] = cur_position+1;
42
- position += 1;
43
- tree_mask[token_tree_idx+cur_position] = true;
44
- int parent_tb_idx = selected_index[bid][cur_position]/topk;
45
- if(parent_tb_idx==0){
46
- break;
47
- }
48
-
49
- int token_idx = parent_list[bid][parent_tb_idx];
50
- for(cur_position=0; cur_position<draft_token_num;cur_position++){
51
- if(selected_index[bid][cur_position]==token_idx){
52
- break;
53
- }
54
- }
55
- }
56
- positions[bid*draft_token_num+tid] = position + seq_len;
57
-
58
- int is_leaf = 0;
59
- for(int i=1;i<draft_token_num;i++){
60
- if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
61
- {
62
- is_leaf ++;
63
- }
64
- }
65
- if(is_leaf==1){
66
- for(int i=0; i<position; i++){
67
- retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
68
- }
69
- retrive_index[bid][tid][0] = bid*draft_token_num;
70
- }
71
-
72
-
73
-
74
- }
75
- //!cuda
76
- """,
77
- float_bits=16, # change to 16 to use half precision as `float` type in the above source code.
78
- boundscheck=True, # turning on for debug and off for performance (to use full threads of a block), default is on.
79
- )
7
+ from sglang.srt.utils import is_cuda_available
8
+
9
+ if is_cuda_available():
10
+ from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
11
+ from sgl_kernel import (
12
+ build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
13
+ )
14
+
15
+
16
+ def build_tree_kernel_efficient_preprocess(
17
+ verified_id: torch.Tensor,
18
+ score_list: List[torch.Tensor],
19
+ token_list: List[torch.Tensor],
20
+ parents_list: List[torch.Tensor],
21
+ num_verify_tokens: int,
22
+ ):
23
+ score_list = torch.cat(score_list, dim=1).flatten(
24
+ 1
25
+ ) # b, n, topk; n= 1 + (num_steps-1) * self.topk
26
+ ss_token_list = torch.cat(
27
+ token_list, dim=1
28
+ ) # b, (self.topk + (num_steps-1) * self.topk)
29
+ top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
30
+ top_scores_index = top_scores.indices
31
+ top_scores_index = torch.sort(top_scores_index).values
32
+
33
+ draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
34
+ draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
35
+ parent_list = torch.cat(parents_list[:-1], dim=1)
36
+
37
+ return parent_list, top_scores_index, draft_tokens
38
+
39
+
40
+ def build_tree_kernel_efficient(
41
+ verified_id: torch.Tensor,
42
+ score_list: List[torch.Tensor],
43
+ token_list: List[torch.Tensor],
44
+ parents_list: List[torch.Tensor],
45
+ seq_lens: torch.Tensor,
46
+ seq_lens_sum: int,
47
+ topk: int,
48
+ spec_steps: int,
49
+ num_verify_tokens: int,
50
+ ):
51
+ parent_list, top_scores_index, draft_tokens = (
52
+ build_tree_kernel_efficient_preprocess(
53
+ verified_id,
54
+ score_list,
55
+ token_list,
56
+ parents_list,
57
+ num_verify_tokens,
58
+ )
59
+ )
60
+
61
+ # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
62
+ bs = seq_lens.numel()
63
+ device = seq_lens.device
64
+ # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
65
+ # where each row indicates the attending pattern of each draft token
66
+ # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
67
+ tree_mask = torch.full(
68
+ (
69
+ seq_lens_sum * num_verify_tokens
70
+ + num_verify_tokens * num_verify_tokens * bs,
71
+ ),
72
+ True,
73
+ device=device,
74
+ )
75
+ retrive_index = torch.full(
76
+ (bs, num_verify_tokens), -1, device=device, dtype=torch.long
77
+ )
78
+ retrive_next_token = torch.full(
79
+ (bs, num_verify_tokens), -1, device=device, dtype=torch.long
80
+ )
81
+ retrive_next_sibling = torch.full(
82
+ (bs, num_verify_tokens), -1, device=device, dtype=torch.long
83
+ )
84
+ # position: where each token belongs to
85
+ # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
86
+ # then, positions = [7, 8, 8, 9]
87
+ positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
88
+
89
+ sgl_build_tree_kernel_efficient(
90
+ parent_list,
91
+ top_scores_index,
92
+ seq_lens.to(torch.int32),
93
+ tree_mask,
94
+ positions,
95
+ retrive_index,
96
+ retrive_next_token,
97
+ retrive_next_sibling,
98
+ topk,
99
+ spec_steps,
100
+ num_verify_tokens,
101
+ )
102
+ return (
103
+ tree_mask,
104
+ positions,
105
+ retrive_index,
106
+ retrive_next_token,
107
+ retrive_next_sibling,
108
+ draft_tokens,
109
+ )
80
110
 
81
111
 
82
112
  def build_tree_kernel(
83
- parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
113
+ verified_id: torch.Tensor,
114
+ score_list: List[torch.Tensor],
115
+ token_list: List[torch.Tensor],
116
+ parents_list: List[torch.Tensor],
117
+ seq_lens: torch.Tensor,
118
+ seq_lens_sum: int,
119
+ topk: int,
120
+ spec_steps: int,
121
+ num_verify_tokens: int,
84
122
  ):
123
+ parent_list, top_scores_index, draft_tokens = (
124
+ build_tree_kernel_efficient_preprocess(
125
+ verified_id,
126
+ score_list,
127
+ token_list,
128
+ parents_list,
129
+ num_verify_tokens,
130
+ )
131
+ )
132
+
85
133
  bs = seq_lens.numel()
86
- device = parent_list.device
134
+ device = seq_lens.device
135
+
87
136
  tree_mask = torch.full(
88
- (seq_lens_sum * draft_token + draft_token * draft_token * bs,),
137
+ (
138
+ seq_lens_sum * num_verify_tokens
139
+ + num_verify_tokens * num_verify_tokens * bs,
140
+ ),
89
141
  True,
90
142
  device=device,
91
143
  )
92
144
  retrive_index = torch.full(
93
- (bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
145
+ (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
94
146
  )
95
- positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
147
+ positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
96
148
 
97
- kernels.build_tree(
149
+ sgl_build_tree_kernel(
98
150
  parent_list,
99
- top_score_index,
151
+ top_scores_index,
100
152
  seq_lens.to(torch.int32),
101
153
  tree_mask,
102
154
  positions,
103
155
  retrive_index,
104
156
  topk,
105
- depth,
106
- draft_token,
107
- grid=(bs, 1, 1),
108
- block=(64, 1, 1),
157
+ spec_steps,
158
+ num_verify_tokens,
109
159
  )
110
- index = retrive_index.sum(dim=-1) != -depth - 2
160
+
161
+ index = retrive_index.sum(dim=-1) != -spec_steps - 2
111
162
  cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
112
163
  retrive_cum_len = torch.zeros(
113
164
  (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
114
165
  )
115
166
  retrive_cum_len[1:] = cum_len
167
+ # TODO: this indexing cause a synchronization, optimize this
116
168
  retrive_index = retrive_index[index]
117
- return tree_mask, positions, retrive_index, retrive_cum_len
169
+ return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
118
170
 
119
171
 
120
- if __name__ == "__main__":
121
-
172
+ def test_build_tree_kernel():
122
173
  def findp(p_i, index, parent_list):
123
174
  pos = index // 10
124
175
  index_list = index.tolist()
@@ -311,21 +362,21 @@ if __name__ == "__main__":
311
362
  bs = verified_seq_len.shape[0]
312
363
  topk = 10
313
364
  depth = 5 # depth <= 10
314
- draft_token = 64
365
+ num_draft_token = 64
315
366
 
316
367
  tree_mask = torch.full(
317
368
  (
318
- torch.sum(verified_seq_len).item() * draft_token
319
- + draft_token * draft_token * bs,
369
+ torch.sum(verified_seq_len).item() * num_draft_token
370
+ + num_draft_token * num_draft_token * bs,
320
371
  ),
321
372
  True,
322
373
  ).cuda()
323
374
  retrive_index = torch.full(
324
- (bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
375
+ (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
325
376
  )
326
- positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
377
+ positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
327
378
 
328
- kernels.build_tree(
379
+ sgl_build_tree_kernel(
329
380
  parent_list.unsqueeze(0),
330
381
  index.unsqueeze(0),
331
382
  verified_seq_len,
@@ -334,16 +385,345 @@ if __name__ == "__main__":
334
385
  retrive_index,
335
386
  topk,
336
387
  depth,
337
- draft_token,
338
- grid=(bs, 1, 1),
339
- block=(64, 1, 1),
388
+ num_draft_token,
340
389
  )
390
+
341
391
  retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
342
392
 
343
393
  c_mask, c_positions, c_retive_index = create_mask(
344
- verified_seq_len, draft_token, index, parent_list, depth
394
+ verified_seq_len, num_draft_token, index, parent_list, depth
345
395
  )
346
396
 
347
397
  assert torch.allclose(tree_mask, c_mask), "tree mask has error."
348
398
  assert torch.allclose(positions, c_positions), "positions has error."
349
399
  assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
400
+
401
+
402
+ def test_build_tree_kernel_efficient():
403
+ verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
404
+ score_list = [
405
+ torch.tensor(
406
+ [
407
+ [[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
408
+ [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
409
+ ],
410
+ dtype=torch.float32,
411
+ device="cuda",
412
+ ),
413
+ torch.tensor(
414
+ [
415
+ [
416
+ [6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
417
+ [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
418
+ [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
419
+ [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
420
+ ],
421
+ [
422
+ [8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
423
+ [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
424
+ [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
425
+ [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
426
+ ],
427
+ ],
428
+ dtype=torch.float32,
429
+ device="cuda",
430
+ ),
431
+ torch.tensor(
432
+ [
433
+ [
434
+ [6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
435
+ [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
436
+ [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
437
+ [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
438
+ ],
439
+ [
440
+ [2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
441
+ [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
442
+ [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
443
+ [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
444
+ ],
445
+ ],
446
+ dtype=torch.float32,
447
+ device="cuda",
448
+ ),
449
+ torch.tensor(
450
+ [
451
+ [
452
+ [6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
453
+ [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
454
+ [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
455
+ [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
456
+ ],
457
+ [
458
+ [5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
459
+ [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
460
+ [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
461
+ [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
462
+ ],
463
+ ],
464
+ dtype=torch.float32,
465
+ device="cuda",
466
+ ),
467
+ ]
468
+ token_list = [
469
+ torch.tensor(
470
+ [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
471
+ dtype=torch.int64,
472
+ device="cuda",
473
+ ),
474
+ torch.tensor(
475
+ [
476
+ [
477
+ 29889,
478
+ 29974,
479
+ 29945,
480
+ 29900,
481
+ 29974,
482
+ 29922,
483
+ 29930,
484
+ 29958,
485
+ 29889,
486
+ 29974,
487
+ 29930,
488
+ 29945,
489
+ 29974,
490
+ 29922,
491
+ 29930,
492
+ 29958,
493
+ ],
494
+ [
495
+ 22550,
496
+ 4136,
497
+ 16492,
498
+ 8439,
499
+ 29871,
500
+ 2,
501
+ 3001,
502
+ 13,
503
+ 2,
504
+ 13,
505
+ 29906,
506
+ 29946,
507
+ 2,
508
+ 13,
509
+ 29871,
510
+ 259,
511
+ ],
512
+ ],
513
+ device="cuda",
514
+ ),
515
+ torch.tensor(
516
+ [
517
+ [
518
+ 29946,
519
+ 29945,
520
+ 29953,
521
+ 29906,
522
+ 29896,
523
+ 29945,
524
+ 29900,
525
+ 29906,
526
+ 29896,
527
+ 29945,
528
+ 29906,
529
+ 29953,
530
+ 29896,
531
+ 29945,
532
+ 29906,
533
+ 29946,
534
+ ],
535
+ [
536
+ 29871,
537
+ 2,
538
+ 29901,
539
+ 29889,
540
+ 29871,
541
+ 2,
542
+ 395,
543
+ 259,
544
+ 29901,
545
+ 29871,
546
+ 2,
547
+ 29889,
548
+ 3001,
549
+ 1234,
550
+ 7146,
551
+ 2186,
552
+ ],
553
+ ],
554
+ device="cuda",
555
+ ),
556
+ torch.tensor(
557
+ [
558
+ [
559
+ 29946,
560
+ 29974,
561
+ 29945,
562
+ 29930,
563
+ 29889,
564
+ 29922,
565
+ 29974,
566
+ 29930,
567
+ 29974,
568
+ 29946,
569
+ 29930,
570
+ 29922,
571
+ 29889,
572
+ 29974,
573
+ 29945,
574
+ 29922,
575
+ ],
576
+ [
577
+ 29941,
578
+ 29906,
579
+ 2,
580
+ 29946,
581
+ 29871,
582
+ 450,
583
+ 319,
584
+ 14990,
585
+ 29946,
586
+ 29941,
587
+ 2,
588
+ 29906,
589
+ 29871,
590
+ 2,
591
+ 3001,
592
+ 13,
593
+ ],
594
+ ],
595
+ device="cuda",
596
+ ),
597
+ ]
598
+ parents_list = [
599
+ torch.tensor(
600
+ [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
601
+ ),
602
+ torch.tensor([[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"),
603
+ torch.tensor(
604
+ [[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
605
+ ),
606
+ torch.tensor(
607
+ [[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
608
+ ),
609
+ ]
610
+ seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
611
+ topk = 4
612
+ depth = 4
613
+ num_draft_token = 8
614
+
615
+ tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
616
+ build_tree_kernel(
617
+ verified_id=verified_id,
618
+ score_list=score_list,
619
+ token_list=token_list,
620
+ parents_list=parents_list,
621
+ seq_lens=seq_lens,
622
+ seq_lens_sum=torch.sum(seq_lens).item(),
623
+ topk=topk,
624
+ spec_steps=depth,
625
+ num_verify_tokens=num_draft_token,
626
+ )
627
+ )
628
+
629
+ from sglang.srt.utils import first_rank_print
630
+
631
+ first_rank_print("=========== build tree kernel ==========")
632
+ # first_rank_print(f"{tree_mask=}", flush=True)
633
+ first_rank_print(f"{position=}", flush=True)
634
+ first_rank_print(f"{retrive_index=}", flush=True)
635
+ first_rank_print(f"{retrive_cum_len=}", flush=True)
636
+ first_rank_print(f"{draft_tokens=}", flush=True)
637
+ assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
638
+ assert retrive_index.tolist() == [
639
+ [0, -1, -1, -1, -1, -1],
640
+ [0, 2, 4, 6, -1, -1],
641
+ [0, 1, 3, 5, 7, -1],
642
+ [8, -1, -1, -1, -1, -1],
643
+ [8, 9, 10, -1, -1, -1],
644
+ [8, 9, 12, -1, -1, -1],
645
+ [8, 9, 13, -1, -1, -1],
646
+ [8, 9, 11, 14, 15, -1],
647
+ ]
648
+ assert retrive_cum_len.tolist() == [0, 3, 8]
649
+ assert draft_tokens.tolist() == [
650
+ 29974,
651
+ 29896,
652
+ 29906,
653
+ 29889,
654
+ 29974,
655
+ 29946,
656
+ 29896,
657
+ 29946,
658
+ 13,
659
+ 13,
660
+ 22550,
661
+ 4136,
662
+ 16492,
663
+ 8439,
664
+ 29871,
665
+ 29941,
666
+ ]
667
+
668
+ (
669
+ tree_mask,
670
+ position,
671
+ retrive_index,
672
+ retrive_next_token,
673
+ retrive_next_sibling,
674
+ draft_tokens,
675
+ ) = build_tree_kernel_efficient(
676
+ verified_id=verified_id,
677
+ score_list=score_list,
678
+ token_list=token_list,
679
+ parents_list=parents_list,
680
+ seq_lens=seq_lens,
681
+ seq_lens_sum=torch.sum(seq_lens).item(),
682
+ topk=topk,
683
+ spec_steps=depth,
684
+ num_verify_tokens=num_draft_token,
685
+ )
686
+
687
+ first_rank_print("=========== build tree kernel efficient ==========")
688
+ # first_rank_print(f"{tree_mask=}", flush=True)
689
+ first_rank_print(f"{position=}", flush=True)
690
+ first_rank_print(f"{retrive_index=}", flush=True)
691
+ first_rank_print(f"{retrive_next_token=}", flush=True)
692
+ first_rank_print(f"{retrive_next_sibling=}", flush=True)
693
+ first_rank_print(f"{draft_tokens=}", flush=True)
694
+ assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
695
+ assert retrive_index.tolist() == [
696
+ [0, 1, 2, 3, 4, 5, 6, 7],
697
+ [8, 9, 10, 11, 12, 13, 14, 15],
698
+ ]
699
+ assert retrive_next_token.tolist() == [
700
+ [1, 3, 4, 5, 6, 7, -1, -1],
701
+ [1, 2, -1, 6, -1, -1, 7, -1],
702
+ ]
703
+ assert retrive_next_sibling.tolist() == [
704
+ [-1, 2, -1, -1, -1, -1, -1, -1],
705
+ [-1, -1, 3, 4, 5, -1, -1, -1],
706
+ ]
707
+ assert draft_tokens.tolist() == [
708
+ 29974,
709
+ 29896,
710
+ 29906,
711
+ 29889,
712
+ 29974,
713
+ 29946,
714
+ 29896,
715
+ 29946,
716
+ 13,
717
+ 13,
718
+ 22550,
719
+ 4136,
720
+ 16492,
721
+ 8439,
722
+ 29871,
723
+ 29941,
724
+ ]
725
+
726
+
727
+ if __name__ == "__main__":
728
+ test_build_tree_kernel_efficient()
729
+ test_build_tree_kernel()
@@ -85,6 +85,7 @@ class EAGLEDraftCudaGraphRunner:
85
85
  "1. disable cuda graph by --disable-cuda-graph\n"
86
86
  "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
87
87
  "3. disable torch compile by not using --enable-torch-compile\n"
88
+ "4. specify --dtype to the same dtype (e.g. bfloat16)\n"
88
89
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
89
90
  )
90
91