sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (78)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/build_eagle_tree.py
@@ -1,122 +1,175 @@
- import cutex
+ # NOTE: Please run this file to make sure the test cases are correct.
+
+ from typing import List
+
  import torch

- # parent_table [bs,topk*depth+)]
- # selected_index [bs,draft_token_num-1)]
- # verified_seq_len [bs]
- # tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
- # positions [bs*draft_token]
- # retrive_index [b, draft_token, depth+2]
- kernels = cutex.SourceModule(
-     """
-     //cuda
-     __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
-                                Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
-         int bid = blockIdx.x;
-         int tid = threadIdx.x;
-         if (tid >= draft_token_num){
-             return;
-         }
-         int seq_tree_idx = draft_token_num * draft_token_num * bid;
-         for(int i=0; i<bid; i++){
-             seq_tree_idx += verified_seq_len[i] * draft_token_num;
-         }
-         int seq_len = verified_seq_len[bid];
-         int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
-         for(int i=0; i<draft_token_num-1; i++){
-             tree_mask[token_tree_idx+i] = false;
-         }
-
-         int position = 0;
-         if (tid==0){
-             positions[bid*draft_token_num] = seq_len;
-             retrive_index[bid][0][0] = bid * draft_token_num;
-             return;
-         }
-
-         int depends_order[10];
-
-         int cur_position = tid-1;
-         while(true){
-             depends_order[position] = cur_position+1;
-             position += 1;
-             tree_mask[token_tree_idx+cur_position] = true;
-             int parent_tb_idx = selected_index[bid][cur_position]/topk;
-             if(parent_tb_idx==0){
-                 break;
-             }
-
-             int token_idx = parent_list[bid][parent_tb_idx];
-             for(cur_position=0; cur_position<draft_token_num;cur_position++){
-                 if(selected_index[bid][cur_position]==token_idx){
-                     break;
-                 }
-             }
-         }
-         positions[bid*draft_token_num+tid] = position + seq_len;
-
-         int is_leaf = 0;
-         for(int i=1;i<draft_token_num;i++){
-             if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
-             {
-                 is_leaf ++;
-             }
-         }
-         if(is_leaf==1){
-             for(int i=0; i<position; i++){
-                 retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
-             }
-             retrive_index[bid][tid][0] = bid*draft_token_num;
-         }
-
-
-
-     }
-     //!cuda
-     """,
-     float_bits=16, # change to 16 to use half precision as `float` type in the above source code.
-     boundscheck=True, # turning on for debug and off for performance (to use full threads of a block), default is on.
- )
-
-
- def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
+ from sglang.srt.utils import is_cuda_available
+
+ if is_cuda_available():
+     from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
+     from sgl_kernel import (
+         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+     )
+
+
+ def build_tree_kernel_efficient_preprocess(
+     verified_id: torch.Tensor,
+     score_list: List[torch.Tensor],
+     token_list: List[torch.Tensor],
+     parents_list: List[torch.Tensor],
+     num_verify_tokens: int,
+ ):
+     score_list = torch.cat(score_list, dim=1).flatten(
+         1
+     )  # b, n, topk; n= 1 + (num_steps-1) * self.topk
+     ss_token_list = torch.cat(
+         token_list, dim=1
+     )  # b, (self.topk + (num_steps-1) * self.topk)
+     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
+     top_scores_index = top_scores.indices
+     top_scores_index = torch.sort(top_scores_index).values
+
+     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
+     parent_list = torch.cat(parents_list[:-1], dim=1)
+
+     return parent_list, top_scores_index, draft_tokens
+
+
+ def build_tree_kernel_efficient(
+     verified_id: torch.Tensor,
+     score_list: List[torch.Tensor],
+     token_list: List[torch.Tensor],
+     parents_list: List[torch.Tensor],
+     seq_lens: torch.Tensor,
+     seq_lens_sum: int,
+     topk: int,
+     spec_steps: int,
+     num_verify_tokens: int,
+ ):
+     parent_list, top_scores_index, draft_tokens = (
+         build_tree_kernel_efficient_preprocess(
+             verified_id,
+             score_list,
+             token_list,
+             parents_list,
+             num_verify_tokens,
+         )
+     )
+
+     # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
      bs = seq_lens.numel()
-     device = parent_list.device
+     device = seq_lens.device
+     # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
+     # where each row indicates the attending pattern of each draft token
+     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
      tree_mask = torch.full(
-         (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
+         (
+             seq_lens_sum * num_verify_tokens
+             + num_verify_tokens * num_verify_tokens * bs,
+         ),
          True,
          device=device,
      )
      retrive_index = torch.full(
-         (bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
+         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
+     )
+     retrive_next_token = torch.full(
+         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
      )
-     positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
+     retrive_next_sibling = torch.full(
+         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
+     )
+     # position: where each token belongs to
+     # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
+     # then, positions = [7, 8, 8, 9]
+     positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)

-     kernels.build_tree(
+     sgl_build_tree_kernel_efficient(
          parent_list,
-         top_score_index,
+         top_scores_index,
          seq_lens.to(torch.int32),
          tree_mask,
          positions,
          retrive_index,
+         retrive_next_token,
+         retrive_next_sibling,
          topk,
-         depth,
-         draft_token,
-         grid=(bs, 1, 1),
-         block=(64, 1, 1),
+         spec_steps,
+         num_verify_tokens,
+     )
+     return (
+         tree_mask,
+         positions,
+         retrive_index,
+         retrive_next_token,
+         retrive_next_sibling,
+         draft_tokens,
+     )
+
+
+ def build_tree_kernel(
+     verified_id: torch.Tensor,
+     score_list: List[torch.Tensor],
+     token_list: List[torch.Tensor],
+     parents_list: List[torch.Tensor],
+     seq_lens: torch.Tensor,
+     seq_lens_sum: int,
+     topk: int,
+     spec_steps: int,
+     num_verify_tokens: int,
+ ):
+     parent_list, top_scores_index, draft_tokens = (
+         build_tree_kernel_efficient_preprocess(
+             verified_id,
+             score_list,
+             token_list,
+             parents_list,
+             num_verify_tokens,
+         )
+     )
+
+     bs = seq_lens.numel()
+     device = seq_lens.device
+
+     tree_mask = torch.full(
+         (
+             seq_lens_sum * num_verify_tokens
+             + num_verify_tokens * num_verify_tokens * bs,
+         ),
+         True,
+         device=device,
+     )
+     retrive_index = torch.full(
+         (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
+     )
+     positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+
+     sgl_build_tree_kernel(
+         parent_list,
+         top_scores_index,
+         seq_lens.to(torch.int32),
+         tree_mask,
+         positions,
+         retrive_index,
+         topk,
+         spec_steps,
+         num_verify_tokens,
      )
-     index = retrive_index.sum(dim=-1) != -depth - 2
+
+     index = retrive_index.sum(dim=-1) != -spec_steps - 2
      cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
      retrive_cum_len = torch.zeros(
          (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
      )
      retrive_cum_len[1:] = cum_len
+     # TODO: this indexing cause a synchronization, optimize this
      retrive_index = retrive_index[index]
-     return tree_mask, positions, retrive_index, retrive_cum_len
+     return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens


- if __name__ == "__main__":
-
+ def test_build_tree_kernel():
      def findp(p_i, index, parent_list):
          pos = index // 10
          index_list = index.tolist()
@@ -309,21 +362,21 @@ if __name__ == "__main__":
      bs = verified_seq_len.shape[0]
      topk = 10
      depth = 5 # depth <= 10
-     draft_token = 64
+     num_draft_token = 64

      tree_mask = torch.full(
          (
-             torch.sum(verified_seq_len).item() * draft_token
-             + draft_token * draft_token * bs,
+             torch.sum(verified_seq_len).item() * num_draft_token
+             + num_draft_token * num_draft_token * bs,
          ),
          True,
      ).cuda()
      retrive_index = torch.full(
-         (bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
+         (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
      )
-     positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
+     positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)

-     kernels.build_tree(
+     sgl_build_tree_kernel(
          parent_list.unsqueeze(0),
          index.unsqueeze(0),
          verified_seq_len,
@@ -332,16 +385,345 @@ if __name__ == "__main__":
          retrive_index,
          topk,
          depth,
-         draft_token,
-         grid=(bs, 1, 1),
-         block=(64, 1, 1),
+         num_draft_token,
      )
+
      retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]

      c_mask, c_positions, c_retive_index = create_mask(
-         verified_seq_len, draft_token, index, parent_list, depth
+         verified_seq_len, num_draft_token, index, parent_list, depth
      )

      assert torch.allclose(tree_mask, c_mask), "tree mask has error."
      assert torch.allclose(positions, c_positions), "positions has error."
      assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
+
+
+ def test_build_tree_kernel_efficient():
+     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
+     score_list = [
+         torch.tensor(
+             [
+                 [[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
+                 [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
+             ],
+             dtype=torch.float32,
+             device="cuda",
+         ),
+         torch.tensor(
+             [
+                 [
+                     [6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
+                     [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
+                     [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
+                     [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
+                 ],
+                 [
+                     [8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
+                     [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
+                     [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
+                     [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
+                 ],
+             ],
+             dtype=torch.float32,
+             device="cuda",
+         ),
+         torch.tensor(
+             [
+                 [
+                     [6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
+                     [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
+                     [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
+                     [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
+                 ],
+                 [
+                     [2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
+                     [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
+                     [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
+                     [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
+                 ],
+             ],
+             dtype=torch.float32,
+             device="cuda",
+         ),
+         torch.tensor(
+             [
+                 [
+                     [6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
+                     [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
+                     [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
+                     [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
+                 ],
+                 [
+                     [5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
+                     [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
+                     [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
+                     [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
+                 ],
+             ],
+             dtype=torch.float32,
+             device="cuda",
+         ),
+     ]
+     token_list = [
+         torch.tensor(
+             [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
+             dtype=torch.int64,
+             device="cuda",
+         ),
+         torch.tensor(
+             [
+                 [29889, 29974, 29945, 29900, 29974, 29922, 29930, 29958,
+                  29889, 29974, 29930, 29945, 29974, 29922, 29930, 29958],
+                 [22550, 4136, 16492, 8439, 29871, 2, 3001, 13,
+                  2, 13, 29906, 29946, 2, 13, 29871, 259],
+             ],
+             device="cuda",
+         ),
+         torch.tensor(
+             [
+                 [29946, 29945, 29953, 29906, 29896, 29945, 29900, 29906,
+                  29896, 29945, 29906, 29953, 29896, 29945, 29906, 29946],
+                 [29871, 2, 29901, 29889, 29871, 2, 395, 259,
+                  29901, 29871, 2, 29889, 3001, 1234, 7146, 2186],
+             ],
+             device="cuda",
+         ),
+         torch.tensor(
+             [
+                 [29946, 29974, 29945, 29930, 29889, 29922, 29974, 29930,
+                  29974, 29946, 29930, 29922, 29889, 29974, 29945, 29922],
+                 [29941, 29906, 2, 29946, 29871, 450, 319, 14990,
+                  29946, 29941, 2, 29906, 29871, 2, 3001, 13],
+             ],
+             device="cuda",
+         ),
+     ]
+     parents_list = [
+         torch.tensor(
+             [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
+         ),
+         torch.tensor([[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"),
+         torch.tensor(
+             [[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
+         ),
+         torch.tensor(
+             [[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
+         ),
+     ]
+     seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
+     topk = 4
+     depth = 4
+     num_draft_token = 8
+
+     tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
+         build_tree_kernel(
+             verified_id=verified_id,
+             score_list=score_list,
+             token_list=token_list,
+             parents_list=parents_list,
+             seq_lens=seq_lens,
+             seq_lens_sum=torch.sum(seq_lens).item(),
+             topk=topk,
+             spec_steps=depth,
+             num_verify_tokens=num_draft_token,
+         )
+     )
+
+     from sglang.srt.utils import first_rank_print
+
+     first_rank_print("=========== build tree kernel ==========")
+     # first_rank_print(f"{tree_mask=}", flush=True)
+     first_rank_print(f"{position=}", flush=True)
+     first_rank_print(f"{retrive_index=}", flush=True)
+     first_rank_print(f"{retrive_cum_len=}", flush=True)
+     first_rank_print(f"{draft_tokens=}", flush=True)
+     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
+     assert retrive_index.tolist() == [
+         [0, -1, -1, -1, -1, -1],
+         [0, 2, 4, 6, -1, -1],
+         [0, 1, 3, 5, 7, -1],
+         [8, -1, -1, -1, -1, -1],
+         [8, 9, 10, -1, -1, -1],
+         [8, 9, 12, -1, -1, -1],
+         [8, 9, 13, -1, -1, -1],
+         [8, 9, 11, 14, 15, -1],
+     ]
+     assert retrive_cum_len.tolist() == [0, 3, 8]
+     assert draft_tokens.tolist() == [
+         29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
+         13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
+     ]
+
+     (
+         tree_mask,
+         position,
+         retrive_index,
+         retrive_next_token,
+         retrive_next_sibling,
+         draft_tokens,
+     ) = build_tree_kernel_efficient(
+         verified_id=verified_id,
+         score_list=score_list,
+         token_list=token_list,
+         parents_list=parents_list,
+         seq_lens=seq_lens,
+         seq_lens_sum=torch.sum(seq_lens).item(),
+         topk=topk,
+         spec_steps=depth,
+         num_verify_tokens=num_draft_token,
+     )
+
+     first_rank_print("=========== build tree kernel efficient ==========")
+     # first_rank_print(f"{tree_mask=}", flush=True)
+     first_rank_print(f"{position=}", flush=True)
+     first_rank_print(f"{retrive_index=}", flush=True)
+     first_rank_print(f"{retrive_next_token=}", flush=True)
+     first_rank_print(f"{retrive_next_sibling=}", flush=True)
+     first_rank_print(f"{draft_tokens=}", flush=True)
+     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
+     assert retrive_index.tolist() == [
+         [0, 1, 2, 3, 4, 5, 6, 7],
+         [8, 9, 10, 11, 12, 13, 14, 15],
+     ]
+     assert retrive_next_token.tolist() == [
+         [1, 3, 4, 5, 6, 7, -1, -1],
+         [1, 2, -1, 6, -1, -1, 7, -1],
+     ]
+     assert retrive_next_sibling.tolist() == [
+         [-1, 2, -1, -1, -1, -1, -1, -1],
+         [-1, -1, 3, 4, 5, -1, -1, -1],
+     ]
+     assert draft_tokens.tolist() == [
+         29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
+         13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
+     ]
+
+
+ if __name__ == "__main__":
+     test_build_tree_kernel_efficient()
+     test_build_tree_kernel()
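
Illustrative sketch (not part of the wheel): the expected values in test_build_tree_kernel_efficient suggest how the new outputs encode the draft tree — retrive_next_token[i] reads as the first child of draft token i and retrive_next_sibling[i] as its next sibling, replacing the dense (bs, num_verify_tokens, spec_steps + 2) path table returned by the old kernel. This first-child/next-sibling reading is inferred from the asserts above, not stated in the package; the small walk below uses only those asserted values and recovers the same root-to-leaf paths that the old retrive_index rows spell out explicitly.

# Pure-Python sketch: traverse the child/sibling arrays from the test
# expectations and collect root-to-leaf paths of each request's draft tree.
def leaf_paths(next_token, next_sibling, offset=0):
    # next_token[i]   : index of the first child of draft token i, or -1 for a leaf
    # next_sibling[i] : index of the next sibling of draft token i, or -1 if it is the last child
    paths = []

    def dfs(node, path):
        path = path + [node + offset]
        child = next_token[node]
        if child == -1:
            paths.append(path)  # leaf reached: record the whole path from the root
            return
        while child != -1:  # visit every child through the sibling chain
            dfs(child, path)
            child = next_sibling[child]

    dfs(0, [])
    return paths


# Values copied from the asserts in test_build_tree_kernel_efficient.
print(leaf_paths([1, 3, 4, 5, 6, 7, -1, -1], [-1, 2, -1, -1, -1, -1, -1, -1]))
# -> [[0, 1, 3, 5, 7], [0, 2, 4, 6]]
print(leaf_paths([1, 2, -1, 6, -1, -1, 7, -1], [-1, -1, 3, 4, 5, -1, -1, -1], offset=8))
# -> [[8, 9, 10], [8, 9, 11, 14, 15], [8, 9, 12], [8, 9, 13]]
# These match the non-root rows of the old retrive_index layout asserted in
# test_build_tree_kernel (e.g. [0, 2, 4, 6] and [8, 9, 11, 14, 15]).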