sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/srt/layers/attention/__init__.py +14 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +211 -81
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/logits_processor.py +167 -212
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +26 -2
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +62 -26
- sglang/srt/managers/tokenizer_manager.py +22 -20
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/model_executor/cuda_graph_runner.py +118 -73
- sglang/srt/model_executor/forward_batch_info.py +33 -8
- sglang/srt/model_executor/model_runner.py +63 -61
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +97 -26
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +21 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +9 -5
- sglang/srt/server_args.py +108 -57
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +618 -0
- sglang/srt/speculative/eagle_worker.py +170 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +15 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -42,7 +43,6 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
-    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = True
     dtype: str = "auto"
@@ -54,6 +54,7 @@
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    skip_tokenizer_init: bool = False
     return_token_ids: bool = False
 
     # Port for the HTTP server
@@ -108,14 +109,6 @@
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -125,6 +118,21 @@
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"
 
+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
@@ -240,6 +248,17 @@
                 "Overlap scheduler is disabled."
             )
 
+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
@@ -276,17 +295,6 @@
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -394,6 +402,17 @@
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output, this may introduce additional overhead.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
@@ -602,43 +621,6 @@
             default=ServerArgs.json_model_override_args,
         )
 
-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The type of heavy channels in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -678,6 +660,75 @@
             help="Choose the backend for grammar-guided decoding.",
         )
 
+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of token sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of token sampled from draft model in eagle2 each step.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
sglang/srt/speculative/build_eagle_tree.py
ADDED
@@ -0,0 +1,347 @@
+import cutex
+import torch
+
+# parent_table [bs,topk*depth+)]
+# selected_index [bs,draft_token_num-1)]
+# verified_seq_len [bs]
+# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
+# positions [bs*draft_token]
+# retrive_index [b, draft_token, depth+2]
+kernels = cutex.SourceModule(
+    """
+//cuda
+__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
+        Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
+    int bid = blockIdx.x;
+    int tid = threadIdx.x;
+    if (tid >= draft_token_num){
+        return;
+    }
+    int seq_tree_idx = draft_token_num * draft_token_num * bid;
+    for(int i=0; i<bid; i++){
+        seq_tree_idx += verified_seq_len[i] * draft_token_num;
+    }
+    int seq_len = verified_seq_len[bid];
+    int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
+    for(int i=0; i<draft_token_num-1; i++){
+        tree_mask[token_tree_idx+i] = false;
+    }
+
+    int position = 0;
+    if (tid==0){
+        positions[bid*draft_token_num] = seq_len;
+        retrive_index[bid][0][0] = bid * draft_token_num;
+        return;
+    }
+
+    int depends_order[10];
+
+    int cur_position = tid-1;
+    while(true){
+        depends_order[position] = cur_position+1;
+        position += 1;
+        tree_mask[token_tree_idx+cur_position] = true;
+        int parent_tb_idx = selected_index[bid][cur_position]/topk;
+        if(parent_tb_idx==0){
+            break;
+        }
+
+        int token_idx = parent_list[bid][parent_tb_idx];
+        for(cur_position=0; cur_position<draft_token_num;cur_position++){
+            if(selected_index[bid][cur_position]==token_idx){
+                break;
+            }
+        }
+    }
+    positions[bid*draft_token_num+tid] = position + seq_len;
+
+    int is_leaf = 0;
+    for(int i=1;i<draft_token_num;i++){
+        if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
+        {
+            is_leaf ++;
+        }
+    }
+    if(is_leaf==1){
+        for(int i=0; i<position; i++){
+            retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
+        }
+        retrive_index[bid][tid][0] = bid*draft_token_num;
+    }
+
+
+
+}
+//!cuda
+""",
+    float_bits=16,  # change to 16 to use half precision as `float` type in the above source code.
+    boundscheck=True,  # turning on for debug and off for performance (to use full threads of a block), default is on.
+)
+
+
+def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
+    bs = seq_lens.numel()
+    device = parent_list.device
+    tree_mask = torch.full(
+        (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
+        True,
+        device=device,
+    )
+    retrive_index = torch.full(
+        (bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
+    )
+    positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
+
+    kernels.build_tree(
+        parent_list,
+        top_score_index,
+        seq_lens.to(torch.int32),
+        tree_mask,
+        positions,
+        retrive_index,
+        topk,
+        depth,
+        draft_token,
+        grid=(bs, 1, 1),
+        block=(64, 1, 1),
+    )
+    index = retrive_index.sum(dim=-1) != -depth - 2
+    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
+    retrive_cum_len = torch.zeros(
+        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
+    )
+    retrive_cum_len[1:] = cum_len
+    retrive_index = retrive_index[index]
+    return tree_mask, positions, retrive_index, retrive_cum_len
+
+
+if __name__ == "__main__":
+
+    def findp(p_i, index, parent_list):
+        pos = index // 10
+        index_list = index.tolist()
+        parent_list = parent_list.tolist()
+        res = [p_i]
+        while True:
+            p = pos[p_i]
+            if p == 0:
+                break
+            token_idx = parent_list[p]
+            p_i = index_list.index(token_idx)
+            res.append(p_i)
+        return res
+
+    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
+        mask = []
+        positions = []
+        retrive_index = []
+        for i, lens in enumerate(seq_len.tolist()):
+            first_mask = torch.full((lens + draft_token,), True)
+            first_mask[-(draft_token - 1) :] = False
+            positions.append(lens)
+            mask.append(first_mask)
+            seq_order = []
+            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
+            r_index = [first_index]
+            for j in range(draft_token - 1):
+                mask.append(torch.full((lens + 1,), True))
+                idx = findp(j, index, parent_list)
+
+                seq_order.append(idx)
+                positions.append(len(idx) + seq_len)
+                t = torch.full((draft_token - 1,), False)
+                t[idx] = True
+                mask.append(t)
+
+            for i in range(1, draft_token - 1):
+                is_leaf = 0
+                for j in range(draft_token - 1):
+                    if i in seq_order[j]:
+                        is_leaf += 1
+
+                if is_leaf == 1:
+                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
+                    for _ in range(max_depth + 1 - len(seq_order[i])):
+                        order_list.append(-1)
+                    order = torch.Tensor(order_list).cuda().to(torch.long)
+                    r_index.append(order)
+            retrive_index.append(torch.stack(r_index))
+
+        return (
+            torch.cat(mask).cuda(),
+            torch.Tensor(positions).cuda().to(torch.long),
+            torch.stack(retrive_index),
+        )
+
+    index = (
+        torch.Tensor(
+            [
+                0,
+                1,
+                2,
+                3,
+                10,
+                11,
+                12,
+                13,
+                20,
+                21,
+                22,
+                30,
+                110,
+                130,
+                150,
+                160,
+                210,
+                211,
+                212,
+                213,
+                214,
+                215,
+                216,
+                217,
+                218,
+                219,
+                220,
+                230,
+                310,
+                311,
+                312,
+                313,
+                314,
+                315,
+                316,
+                317,
+                320,
+                321,
+                322,
+                330,
+                360,
+                380,
+                390,
+                410,
+                411,
+                412,
+                413,
+                414,
+                415,
+                416,
+                417,
+                418,
+                419,
+                420,
+                421,
+                422,
+                423,
+                430,
+                431,
+                440,
+                441,
+                460,
+                470,
+            ]
+        )
+        .to(torch.long)
+        .cuda()
+    )
+
+    parent_list = (
+        torch.Tensor(
+            [
+                -1,
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                11,
+                12,
+                20,
+                30,
+                21,
+                13,
+                22,
+                40,
+                23,
+                110,
+                130,
+                160,
+                150,
+                190,
+                120,
+                111,
+                121,
+                200,
+                180,
+                210,
+                211,
+                212,
+                213,
+                214,
+                215,
+                216,
+                220,
+                230,
+                217,
+                310,
+                311,
+                312,
+                313,
+                320,
+                314,
+                321,
+                315,
+                316,
+                317,
+            ]
+        )
+        .to(torch.long)
+        .cuda()
+    )
+
+    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
+    bs = verified_seq_len.shape[0]
+    topk = 10
+    depth = 5  # depth <= 10
+    draft_token = 64
+
+    tree_mask = torch.full(
+        (
+            torch.sum(verified_seq_len).item() * draft_token
+            + draft_token * draft_token * bs,
+        ),
+        True,
+    ).cuda()
+    retrive_index = torch.full(
+        (bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
+    )
+    positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
+
+    kernels.build_tree(
+        parent_list.unsqueeze(0),
+        index.unsqueeze(0),
+        verified_seq_len,
+        tree_mask,
+        positions,
+        retrive_index,
+        topk,
+        depth,
+        draft_token,
+        grid=(bs, 1, 1),
+        block=(64, 1, 1),
+    )
+    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
+
+    c_mask, c_positions, c_retive_index = create_mask(
+        verified_seq_len, draft_token, index, parent_list, depth
+    )
+
+    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
+    assert torch.allclose(positions, c_positions), "positions has error."
+    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."