sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  31. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  32. sglang/srt/layers/quantization/fp8.py +2 -2
  33. sglang/srt/layers/sampler.py +57 -21
  34. sglang/srt/layers/torchao_utils.py +17 -3
  35. sglang/srt/managers/io_struct.py +1 -2
  36. sglang/srt/managers/schedule_batch.py +26 -2
  37. sglang/srt/managers/schedule_policy.py +159 -90
  38. sglang/srt/managers/scheduler.py +62 -26
  39. sglang/srt/managers/tokenizer_manager.py +22 -20
  40. sglang/srt/managers/tp_worker.py +16 -4
  41. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  42. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  43. sglang/srt/model_executor/forward_batch_info.py +33 -8
  44. sglang/srt/model_executor/model_runner.py +63 -61
  45. sglang/srt/models/deepseek_v2.py +34 -7
  46. sglang/srt/models/grok.py +97 -26
  47. sglang/srt/openai_api/adapter.py +0 -17
  48. sglang/srt/openai_api/protocol.py +3 -3
  49. sglang/srt/sampling/sampling_batch_info.py +21 -0
  50. sglang/srt/sampling/sampling_params.py +9 -1
  51. sglang/srt/server.py +9 -5
  52. sglang/srt/server_args.py +108 -57
  53. sglang/srt/speculative/build_eagle_tree.py +347 -0
  54. sglang/srt/speculative/eagle_utils.py +618 -0
  55. sglang/srt/speculative/eagle_worker.py +170 -0
  56. sglang/srt/speculative/spec_info.py +5 -0
  57. sglang/srt/utils.py +15 -2
  58. sglang/version.py +1 -1
  59. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  60. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
  61. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  62. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  63. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
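The headline changes are a new EAGLE speculative decoding stack under sglang/srt/speculative/, a batch of fused-MoE Triton tuning configs for NVIDIA H200, and a reorganization of server_args.py. The server_args.py diff and the new build_eagle_tree.py file are reproduced below.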
sglang/srt/server_args.py CHANGED
@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -42,7 +43,6 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
-    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = True
     dtype: str = "auto"
@@ -54,6 +54,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    skip_tokenizer_init: bool = False
     return_token_ids: bool = False
 
     # Port for the HTTP server
@@ -108,14 +109,6 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -125,6 +118,21 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"
 
+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
@@ -240,6 +248,17 @@ class ServerArgs:
                 "Overlap scheduler is disabled."
             )
 
+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because EAGLE speculative decoding is used."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
@@ -276,17 +295,6 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -394,6 +402,17 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output, this may introduce additional overhead.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
@@ -602,43 +621,6 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The threshold for sparse decoding in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -678,6 +660,75 @@ class ServerArgs:
             help="Choose the backend for grammar-guided decoding.",
         )
 
+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from the draft model in speculative decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of tokens sampled from the draft model in speculative decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of tokens sampled from the draft model at each step in EAGLE-2.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The threshold for sparse decoding in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
sglang/srt/speculative/build_eagle_tree.py ADDED
@@ -0,0 +1,347 @@
+import cutex
+import torch
+
+# parent_list      [bs, topk * depth + 1]
+# selected_index   [bs, draft_token_num - 1]
+# verified_seq_len [bs]
+# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ...] = [sum(verified_seq_len)*draft_token + bs*draft_token*draft_token]
+# positions [bs * draft_token]
+# retrive_index [bs, draft_token, depth + 2]
+kernels = cutex.SourceModule(
+    """
+//cuda
+__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
+        Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
+    // one block per request, one thread per draft token
+    int bid = blockIdx.x;
+    int tid = threadIdx.x;
+    if (tid >= draft_token_num) {
+        return;
+    }
+    // offset of this request's rows in the flattened tree_mask
+    int seq_tree_idx = draft_token_num * draft_token_num * bid;
+    for (int i = 0; i < bid; i++) {
+        seq_tree_idx += verified_seq_len[i] * draft_token_num;
+    }
+    int seq_len = verified_seq_len[bid];
+    int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
+    // clear the draft-token segment of this token's row (the verified prefix stays visible)
+    for (int i = 0; i < draft_token_num - 1; i++) {
+        tree_mask[token_tree_idx + i] = false;
+    }
+
+    int position = 0;
+    if (tid == 0) {
+        positions[bid * draft_token_num] = seq_len;
+        retrive_index[bid][0][0] = bid * draft_token_num;
+        return;
+    }
+
+    int depends_order[10];
+
+    // walk up the draft tree, re-enabling this token's ancestors in the mask
+    int cur_position = tid - 1;
+    while (true) {
+        depends_order[position] = cur_position + 1;
+        position += 1;
+        tree_mask[token_tree_idx + cur_position] = true;
+        int parent_tb_idx = selected_index[bid][cur_position] / topk;
+        if (parent_tb_idx == 0) {
+            break;
+        }
+
+        int token_idx = parent_list[bid][parent_tb_idx];
+        for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
+            if (selected_index[bid][cur_position] == token_idx) {
+                break;
+            }
+        }
+    }
+    positions[bid * draft_token_num + tid] = position + seq_len;
+
+    // a token attended to by exactly one draft token (itself) is a leaf;
+    // record the root-to-leaf path for leaves only
+    int is_leaf = 0;
+    for (int i = 1; i < draft_token_num; i++) {
+        if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
+            is_leaf++;
+        }
+    }
+    if (is_leaf == 1) {
+        for (int i = 0; i < position; i++) {
+            retrive_index[bid][tid][position - i] = depends_order[i] + bid * draft_token_num;
+        }
+        retrive_index[bid][tid][0] = bid * draft_token_num;
+    }
+}
+//!cuda
+    """,
+    float_bits=16,  # use half precision as the `float` type in the kernel source
+    boundscheck=True,  # on for debugging, off for performance (to use all threads of a block); default is on
+)
+
+
+def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
+    bs = seq_lens.numel()
+    device = parent_list.device
+    tree_mask = torch.full(
+        (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
+        True,
+        device=device,
+    )
+    retrive_index = torch.full(
+        (bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
+    )
+    positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
+
+    kernels.build_tree(
+        parent_list,
+        top_score_index,
+        seq_lens.to(torch.int32),
+        tree_mask,
+        positions,
+        retrive_index,
+        topk,
+        depth,
+        draft_token,
+        grid=(bs, 1, 1),
+        block=(64, 1, 1),
+    )
+    # keep only the rows the kernel filled; a row left at -1 everywhere sums to -depth - 2
+    index = retrive_index.sum(dim=-1) != -depth - 2
+    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
+    retrive_cum_len = torch.zeros(
+        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
+    )
+    retrive_cum_len[1:] = cum_len
+    retrive_index = retrive_index[index]
+    return tree_mask, positions, retrive_index, retrive_cum_len
+
+
+if __name__ == "__main__":
+
+    def findp(p_i, index, parent_list):
+        pos = index // 10  # 10 == topk in the test below
+        index_list = index.tolist()
+        parent_list = parent_list.tolist()
+        res = [p_i]
+        while True:
+            p = pos[p_i]
+            if p == 0:
+                break
+            token_idx = parent_list[p]
+            p_i = index_list.index(token_idx)
+            res.append(p_i)
+        return res
+
+    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
+        mask = []
+        positions = []
+        retrive_index = []
+        for i, lens in enumerate(seq_len.tolist()):
+            first_mask = torch.full((lens + draft_token,), True)
+            first_mask[-(draft_token - 1) :] = False
+            positions.append(lens)
+            mask.append(first_mask)
+            seq_order = []
+            first_index = (
+                torch.Tensor([0] + [-1] * (max_depth + 1)).cuda().to(torch.long)
+            )
+            r_index = [first_index]
+            for j in range(draft_token - 1):
+                mask.append(torch.full((lens + 1,), True))
+                idx = findp(j, index, parent_list)
+
+                seq_order.append(idx)
+                positions.append(len(idx) + lens)
+                t = torch.full((draft_token - 1,), False)
+                t[idx] = True
+                mask.append(t)
+
+            for i in range(1, draft_token - 1):
+                is_leaf = 0
+                for j in range(draft_token - 1):
+                    if i in seq_order[j]:
+                        is_leaf += 1
+
+                if is_leaf == 1:
+                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
+                    for _ in range(max_depth + 1 - len(seq_order[i])):
+                        order_list.append(-1)
+                    order = torch.Tensor(order_list).cuda().to(torch.long)
+                    r_index.append(order)
+            retrive_index.append(torch.stack(r_index))
+
+        return (
+            torch.cat(mask).cuda(),
+            torch.Tensor(positions).cuda().to(torch.long),
+            torch.stack(retrive_index),
+        )
+
+    index = (
+        torch.Tensor(
+            [
+                0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 30, 110, 130, 150, 160,
+                210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 230,
+                310, 311, 312, 313, 314, 315, 316, 317, 320, 321, 322, 330,
+                360, 380, 390, 410, 411, 412, 413, 414, 415, 416, 417, 418,
+                419, 420, 421, 422, 423, 430, 431, 440, 441, 460, 470,
+            ]
+        )
+        .to(torch.long)
+        .cuda()
+    )
+
+    parent_list = (
+        torch.Tensor(
+            [
+                -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 20, 30, 21,
+                13, 22, 40, 23, 110, 130, 160, 150, 190, 120, 111, 121, 200,
+                180, 210, 211, 212, 213, 214, 215, 216, 220, 230, 217, 310,
+                311, 312, 313, 320, 314, 321, 315, 316, 317,
+            ]
+        )
+        .to(torch.long)
+        .cuda()
+    )
+
+    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
+    bs = verified_seq_len.shape[0]
+    topk = 10
+    depth = 5  # depth <= 10
+    draft_token = 64
+
+    tree_mask = torch.full(
+        (
+            torch.sum(verified_seq_len).item() * draft_token
+            + draft_token * draft_token * bs,
+        ),
+        True,
+    ).cuda()
+    retrive_index = torch.full(
+        (bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
+    )
+    positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
+
+    kernels.build_tree(
+        parent_list.unsqueeze(0),
+        index.unsqueeze(0),
+        verified_seq_len.to(torch.int32),
+        tree_mask,
+        positions,
+        retrive_index,
+        topk,
+        depth,
+        draft_token,
+        grid=(bs, 1, 1),
+        block=(64, 1, 1),
+    )
+    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
+
+    c_mask, c_positions, c_retrive_index = create_mask(
+        verified_seq_len, draft_token, index, parent_list, depth
+    )
+
+    assert torch.equal(tree_mask, c_mask), "tree_mask mismatch."
+    assert torch.allclose(positions, c_positions), "positions mismatch."
+    assert torch.allclose(retrive_index, c_retrive_index), "retrive_index mismatch."
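For orientation, here is a small sketch of the flattened tree_mask layout that the comments at the top of this file describe, using the same numbers as the self-test above (an illustration, not part of the package):

    import torch

    bs, draft_token = 1, 64
    seq_lens = torch.tensor([47])
    # one (seq_len + draft_token)-wide attention row per draft token, per request
    numel = int(seq_lens.sum()) * draft_token + bs * draft_token * draft_token
    assert numel == 47 * 64 + 64 * 64  # 7104 flattened mask entries

Each row starts with the verified prefix fully visible; build_tree clears the draft-token segment and re-enables only the token itself and its ancestors, so positions[i] comes out as seq_len plus the token's depth in the draft tree, and retrive_index keeps one root-to-leaf path per leaf, padded with -1 to depth + 2 entries.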