sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/triton_backend.py +64 -16
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +43 -10
- sglang/srt/models/llama.py +8 -3
- sglang/srt/speculative/build_eagle_tree.py +482 -102
- sglang/srt/speculative/eagle_utils.py +80 -50
- sglang/version.py +1 -1
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +2 -2
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +16 -16
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py
CHANGED
@@ -476,9 +476,14 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading kv_scale from ckpts towards new design.
                 if name.endswith(".kv_scale") and name not in params_dict:
                     continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100, tp_size: int = 1
sglang/srt/speculative/build_eagle_tree.py
CHANGED
@@ -1,124 +1,175 @@
+# NOTE: Please run this file to make sure the test cases are correct.
+
+from typing import List
+
 import torch
 
-)
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )
+
+
+def build_tree_kernel_efficient_preprocess(
+    verified_id: torch.Tensor,
+    score_list: List[torch.Tensor],
+    token_list: List[torch.Tensor],
+    parents_list: List[torch.Tensor],
+    num_verify_tokens: int,
+):
+    score_list = torch.cat(score_list, dim=1).flatten(
+        1
+    )  # b, n, topk; n= 1 + (num_steps-1) * self.topk
+    ss_token_list = torch.cat(
+        token_list, dim=1
+    )  # b, (self.topk + (num_steps-1) * self.topk)
+    top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
+    top_scores_index = top_scores.indices
+    top_scores_index = torch.sort(top_scores_index).values
+
+    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
+    parent_list = torch.cat(parents_list[:-1], dim=1)
+
+    return parent_list, top_scores_index, draft_tokens
+
+
+def build_tree_kernel_efficient(
+    verified_id: torch.Tensor,
+    score_list: List[torch.Tensor],
+    token_list: List[torch.Tensor],
+    parents_list: List[torch.Tensor],
+    seq_lens: torch.Tensor,
+    seq_lens_sum: int,
+    topk: int,
+    spec_steps: int,
+    num_verify_tokens: int,
+):
+    parent_list, top_scores_index, draft_tokens = (
+        build_tree_kernel_efficient_preprocess(
+            verified_id,
+            score_list,
+            token_list,
+            parents_list,
+            num_verify_tokens,
+        )
+    )
+
+    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
+    bs = seq_lens.numel()
+    device = seq_lens.device
+    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
+    # where each row indicates the attending pattern of each draft token
+    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
+    tree_mask = torch.full(
+        (
+            seq_lens_sum * num_verify_tokens
+            + num_verify_tokens * num_verify_tokens * bs,
+        ),
+        True,
+        device=device,
+    )
+    retrive_index = torch.full(
+        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
+    )
+    retrive_next_token = torch.full(
+        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
+    )
+    retrive_next_sibling = torch.full(
+        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
+    )
+    # position: where each token belongs to
+    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
+    # then, positions = [7, 8, 8, 9]
+    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+
+    sgl_build_tree_kernel_efficient(
+        parent_list,
+        top_scores_index,
+        seq_lens.to(torch.int32),
+        tree_mask,
+        positions,
+        retrive_index,
+        retrive_next_token,
+        retrive_next_sibling,
+        topk,
+        spec_steps,
+        num_verify_tokens,
+    )
+    return (
+        tree_mask,
+        positions,
+        retrive_index,
+        retrive_next_token,
+        retrive_next_sibling,
+        draft_tokens,
+    )
 
 
 def build_tree_kernel(
+    verified_id: torch.Tensor,
+    score_list: List[torch.Tensor],
+    token_list: List[torch.Tensor],
+    parents_list: List[torch.Tensor],
+    seq_lens: torch.Tensor,
+    seq_lens_sum: int,
+    topk: int,
+    spec_steps: int,
+    num_verify_tokens: int,
 ):
+    parent_list, top_scores_index, draft_tokens = (
+        build_tree_kernel_efficient_preprocess(
+            verified_id,
+            score_list,
+            token_list,
+            parents_list,
+            num_verify_tokens,
+        )
+    )
+
     bs = seq_lens.numel()
-    device =
+    device = seq_lens.device
+
     tree_mask = torch.full(
-        (
+        (
+            seq_lens_sum * num_verify_tokens
+            + num_verify_tokens * num_verify_tokens * bs,
+        ),
         True,
         device=device,
     )
     retrive_index = torch.full(
-        (bs,
+        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
     )
-    positions = torch.empty((bs *
+    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
 
+    sgl_build_tree_kernel(
         parent_list,
+        top_scores_index,
         seq_lens.to(torch.int32),
         tree_mask,
         positions,
         retrive_index,
         topk,
-        grid=(bs, 1, 1),
-        block=(64, 1, 1),
+        spec_steps,
+        num_verify_tokens,
     )
+
+    index = retrive_index.sum(dim=-1) != -spec_steps - 2
     cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
     retrive_cum_len = torch.zeros(
         (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
     )
     retrive_cum_len[1:] = cum_len
+    # TODO: this indexing cause a synchronization, optimize this
     retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len
+    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
 
 
+def test_build_tree_kernel():
     def findp(p_i, index, parent_list):
         pos = index // 10
         index_list = index.tolist()
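In build_tree_kernel_efficient_preprocess above, the per-step candidate scores are flattened, the top num_verify_tokens - 1 candidates per sequence are kept, their indices are re-sorted back into generation order, the matching tokens are gathered, and the last verified token is prepended, so every sequence contributes exactly num_verify_tokens draft tokens. A toy CPU illustration of that selection step (made-up scores and token ids, not the module's code):

import torch

# One sequence, five flattened candidates with made-up scores and token ids.
scores = torch.tensor([[0.70, 0.05, 0.20, 0.03, 0.02]])
tokens = torch.tensor([[101, 102, 103, 104, 105]])
verified_id = torch.tensor([42])
num_verify_tokens = 4

# Keep the best num_verify_tokens - 1 candidates, then restore their original order.
top_index = torch.topk(scores, num_verify_tokens - 1, dim=-1).indices  # tensor([[0, 2, 1]])
top_index = torch.sort(top_index).values                               # tensor([[0, 1, 2]])

draft = torch.gather(tokens, dim=1, index=top_index)                   # tensor([[101, 102, 103]])
draft = torch.cat((verified_id.unsqueeze(1), draft), dim=1).flatten()
print(draft)  # tensor([ 42, 101, 102, 103]) -> num_verify_tokens tokens for this sequence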
@@ -311,21 +362,21 @@ if __name__ == "__main__":
     bs = verified_seq_len.shape[0]
     topk = 10
     depth = 5  # depth <= 10
+    num_draft_token = 64
 
     tree_mask = torch.full(
         (
-            torch.sum(verified_seq_len).item() *
-            +
+            torch.sum(verified_seq_len).item() * num_draft_token
+            + num_draft_token * num_draft_token * bs,
         ),
         True,
     ).cuda()
     retrive_index = torch.full(
-        (bs,
+        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
     )
-    positions = torch.empty((bs *
+    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
 
+    sgl_build_tree_kernel(
         parent_list.unsqueeze(0),
         index.unsqueeze(0),
         verified_seq_len,
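Both builders allocate the flattened tree_mask with the same size formula: seq_lens_sum * num_draft_token + num_draft_token * num_draft_token * bs, i.e. one (num_draft_token, seq_len + num_draft_token) attention mask per sequence laid out back to back. A quick sanity check using the values from the second test below (bs=2, seq_lens=[5, 10], num_draft_token=8); plain arithmetic, not library code:

seq_lens = [5, 10]
num_draft_token = 8
bs = len(seq_lens)

# Flattened size: one boolean row per draft token, columns = prompt + draft tokens.
flat_size = sum(seq_lens) * num_draft_token + num_draft_token * num_draft_token * bs
per_sequence = [(num_draft_token, s + num_draft_token) for s in seq_lens]

print(flat_size)     # 248
print(per_sequence)  # [(8, 13), (8, 18)] -> 8*13 + 8*18 == 248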
@@ -334,16 +385,345 @@ if __name__ == "__main__":
         retrive_index,
         topk,
         depth,
-        grid=(bs, 1, 1),
-        block=(64, 1, 1),
+        num_draft_token,
     )
+
     retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
 
     c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len,
+        verified_seq_len, num_draft_token, index, parent_list, depth
     )
 
     assert torch.allclose(tree_mask, c_mask), "tree mask has error."
     assert torch.allclose(positions, c_positions), "positions has error."
     assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
+
+
+def test_build_tree_kernel_efficient():
+    verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
+    score_list = [
+        torch.tensor(
+            [
+                [[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
+                [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
+            ],
+            dtype=torch.float32,
+            device="cuda",
+        ),
+        torch.tensor(
+            [
+                [
+                    [6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
+                    [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
+                    [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
+                    [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
+                ],
+                [
+                    [8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
+                    [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
+                    [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
+                    [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
+                ],
+            ],
+            dtype=torch.float32,
+            device="cuda",
+        ),
+        torch.tensor(
+            [
+                [
+                    [6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
+                    [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
+                    [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
+                    [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
+                ],
+                [
+                    [2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
+                    [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
+                    [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
+                    [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
+                ],
+            ],
+            dtype=torch.float32,
+            device="cuda",
+        ),
+        torch.tensor(
+            [
+                [
+                    [6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
+                    [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
+                    [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
+                    [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
+                ],
+                [
+                    [5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
+                    [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
+                    [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
+                    [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
+                ],
+            ],
+            dtype=torch.float32,
+            device="cuda",
+        ),
+    ]
+    token_list = [
+        torch.tensor(
+            [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
+            dtype=torch.int64,
+            device="cuda",
+        ),
+        torch.tensor(
+            [
+                [29889, 29974, 29945, 29900, 29974, 29922, 29930, 29958,
+                 29889, 29974, 29930, 29945, 29974, 29922, 29930, 29958],
+                [22550, 4136, 16492, 8439, 29871, 2, 3001, 13,
+                 2, 13, 29906, 29946, 2, 13, 29871, 259],
+            ],
+            device="cuda",
+        ),
+        torch.tensor(
+            [
+                [29946, 29945, 29953, 29906, 29896, 29945, 29900, 29906,
+                 29896, 29945, 29906, 29953, 29896, 29945, 29906, 29946],
+                [29871, 2, 29901, 29889, 29871, 2, 395, 259,
+                 29901, 29871, 2, 29889, 3001, 1234, 7146, 2186],
+            ],
+            device="cuda",
+        ),
+        torch.tensor(
+            [
+                [29946, 29974, 29945, 29930, 29889, 29922, 29974, 29930,
+                 29974, 29946, 29930, 29922, 29889, 29974, 29945, 29922],
+                [29941, 29906, 2, 29946, 29871, 450, 319, 14990,
+                 29946, 29941, 2, 29906, 29871, 2, 3001, 13],
+            ],
+            device="cuda",
+        ),
+    ]
+    parents_list = [
+        torch.tensor(
+            [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
+        ),
+        torch.tensor([[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"),
+        torch.tensor(
+            [[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
+        ),
+        torch.tensor(
+            [[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
+        ),
+    ]
+    seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
+    topk = 4
+    depth = 4
+    num_draft_token = 8
+
+    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
+        build_tree_kernel(
+            verified_id=verified_id,
+            score_list=score_list,
+            token_list=token_list,
+            parents_list=parents_list,
+            seq_lens=seq_lens,
+            seq_lens_sum=torch.sum(seq_lens).item(),
+            topk=topk,
+            spec_steps=depth,
+            num_verify_tokens=num_draft_token,
+        )
+    )
+
+    from sglang.srt.utils import first_rank_print
+
+    first_rank_print("=========== build tree kernel ==========")
+    # first_rank_print(f"{tree_mask=}", flush=True)
+    first_rank_print(f"{position=}", flush=True)
+    first_rank_print(f"{retrive_index=}", flush=True)
+    first_rank_print(f"{retrive_cum_len=}", flush=True)
+    first_rank_print(f"{draft_tokens=}", flush=True)
+    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
+    assert retrive_index.tolist() == [
+        [0, -1, -1, -1, -1, -1],
+        [0, 2, 4, 6, -1, -1],
+        [0, 1, 3, 5, 7, -1],
+        [8, -1, -1, -1, -1, -1],
+        [8, 9, 10, -1, -1, -1],
+        [8, 9, 12, -1, -1, -1],
+        [8, 9, 13, -1, -1, -1],
+        [8, 9, 11, 14, 15, -1],
+    ]
+    assert retrive_cum_len.tolist() == [0, 3, 8]
+    assert draft_tokens.tolist() == [
+        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
+        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
+    ]
+
+    (
+        tree_mask,
+        position,
+        retrive_index,
+        retrive_next_token,
+        retrive_next_sibling,
+        draft_tokens,
+    ) = build_tree_kernel_efficient(
+        verified_id=verified_id,
+        score_list=score_list,
+        token_list=token_list,
+        parents_list=parents_list,
+        seq_lens=seq_lens,
+        seq_lens_sum=torch.sum(seq_lens).item(),
+        topk=topk,
+        spec_steps=depth,
+        num_verify_tokens=num_draft_token,
+    )
+
+    first_rank_print("=========== build tree kernel efficient ==========")
+    # first_rank_print(f"{tree_mask=}", flush=True)
+    first_rank_print(f"{position=}", flush=True)
+    first_rank_print(f"{retrive_index=}", flush=True)
+    first_rank_print(f"{retrive_next_token=}", flush=True)
+    first_rank_print(f"{retrive_next_sibling=}", flush=True)
+    first_rank_print(f"{draft_tokens=}", flush=True)
+    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
+    assert retrive_index.tolist() == [
+        [0, 1, 2, 3, 4, 5, 6, 7],
+        [8, 9, 10, 11, 12, 13, 14, 15],
+    ]
+    assert retrive_next_token.tolist() == [
+        [1, 3, 4, 5, 6, 7, -1, -1],
+        [1, 2, -1, 6, -1, -1, 7, -1],
+    ]
+    assert retrive_next_sibling.tolist() == [
+        [-1, 2, -1, -1, -1, -1, -1, -1],
+        [-1, -1, 3, 4, 5, -1, -1, -1],
+    ]
+    assert draft_tokens.tolist() == [
+        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
+        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
+    ]
+
+
+if __name__ == "__main__":
+    test_build_tree_kernel_efficient()
+    test_build_tree_kernel()