sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py

@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
 if is_cuda():
     from sgl_kernel import (
@@ -32,6 +32,7 @@ if is_cuda():
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
+    from sgl_kernel.top_k import fast_topk
 elif is_hip():
     from sgl_kernel import verify_tree_greedy
 
@@ -67,8 +68,6 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
-    all_padding_lens: Optional[torch.Tensor] = None
-
     def prepare_for_extend(self, batch: ScheduleBatch):
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
@@ -93,6 +92,7 @@ class EagleDraftInput:
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         batch.return_logprob = False
+        batch.return_hidden_states = False
 
         self.capture_hidden_mode = CaptureHiddenMode.LAST
         self.accept_length.add_(1)
@@ -116,13 +116,14 @@ class EagleDraftInput:
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
-
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
+        if paged_kernel_lens_sum is None:
+            paged_kernel_lens_sum = cum_kv_seq_len[-1]
+
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
@@ -136,7 +137,6 @@ class EagleDraftInput:
             kv_indices,
             req_to_token.size(1),
         )
-
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
     def filter_batch(self, new_indices: torch.Tensor):
@@ -267,7 +267,7 @@ class EagleVerifyInput:
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
-        vocab_mask: Optional[torch.Tensor] = None,
+        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -291,6 +291,14 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
 
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(
+                logits_output.next_token_logits,
+                sampling_info,
+                num_tokens_in_batch=self.draft_token_num,
+            )
+
         # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
@@ -320,11 +328,11 @@ class EagleVerifyInput:
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates
-                retrive_index=self.retrive_index
-                retrive_next_token=self.retrive_next_token
-                retrive_next_sibling=self.retrive_next_sibling
-                target_predict=target_predict
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
+                target_predict=target_predict,
             )
         else:
             # apply temperature and get target probs
@@ -352,16 +360,23 @@ class EagleVerifyInput:
             draft_probs = torch.zeros(
                 target_probs.shape, dtype=torch.float32, device="cuda"
             )
+
+            # coins for rejection sampling
             coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
+            # coins for final sampling
+            coins_for_final_sampling = torch.rand(
+                (bs,), dtype=torch.float32, device="cuda"
+            )
             tree_speculative_sampling_target_only(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates
-                retrive_index=self.retrive_index
-                retrive_next_token=self.retrive_next_token
-                retrive_next_sibling=self.retrive_next_sibling
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
                 uniform_samples=coins,
+                uniform_samples_for_final_sampling=coins_for_final_sampling,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
                 threshold_single=global_server_args_dict[
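The hunk above draws two sets of uniform "coins": one per draft token for the accept/reject decisions and one extra coin per sequence for the final sample taken when the draft tokens are rejected. As a rough, generic illustration of the acceptance rule such coins feed, here is a minimal sketch of standard chain-style speculative sampling for a single position; it is not sglang's fused tree kernel, and the helper name and toy distributions are made up for the example.

```python
import torch


def accept_or_resample(p_target, p_draft, draft_token, coin, final_coin):
    # Accept the draft token with probability min(1, p_target/p_draft); on
    # rejection, draw from the clamped residual distribution using the
    # pre-generated final-sampling coin (inverse-CDF sampling).
    accept_prob = min(1.0, (p_target[draft_token] / p_draft[draft_token]).item())
    if coin < accept_prob:
        return draft_token, True
    residual = torch.clamp(p_target - p_draft, min=0.0)
    residual = residual / residual.sum()
    resampled = int(torch.searchsorted(residual.cumsum(0), torch.tensor(final_coin)))
    return resampled, False


p_draft = torch.tensor([0.7, 0.2, 0.1])
p_target = torch.tensor([0.5, 0.3, 0.2])
print(accept_or_resample(p_target, p_draft, draft_token=0, coin=0.9, final_coin=0.4))
# -> (1, False): the draft token is rejected and token 1 is drawn from the residual.
```

The fused kernel performs the same bookkeeping, but over a whole token tree and batch at once, which is why both coin tensors are pre-generated and passed in.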
@@ -384,8 +399,8 @@ class EagleVerifyInput:
             spec_steps=self.spec_steps,
         )
 
-        new_accept_index = []
         unfinished_index = []
+        unfinished_accept_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
@@ -393,12 +408,10 @@ class EagleVerifyInput:
         # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
-            new_accept_index_ = []
             for j, idx in enumerate(accept_index_row):
                 if idx == -1:
                     break
                 id = predict_cpu[idx]
-                # if not found_finished:
                 req.output_ids.append(id)
                 req.check_finished()
                 if req.finished():
@@ -407,8 +420,6 @@ class EagleVerifyInput:
                     accept_index[i, j + 1 :] = -1
                     break
                 else:
-                    new_accept_index_.append(idx)
-                    # update grammar state
                     if req.grammar is not None:
                         try:
                             req.grammar.accept_token(id)
@@ -418,50 +429,104 @@ class EagleVerifyInput:
                             )
                             raise e
             if not req.finished():
-                new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
+                if idx == -1:
+                    unfinished_accept_index.append(accept_index[i, :j])
+                else:
+                    unfinished_accept_index.append(accept_index[i])
             req.spec_verify_ct += 1
 
         if has_finished:
             accept_length = (accept_index != -1).sum(dim=1) - 1
 
         # Free the KV cache for unaccepted tokens
+        # TODO: fuse them
        accept_index = accept_index[accept_index != -1]
        verified_id = predict[accept_index]
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[accept_index] = False
 
-        if page_size
-
-
-
-
-
-
-
+        if page_size == 1:
+            # TODO: boolean array index leads to a device sync. Remove it.
+            token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+        else:
+            if self.topk == 1:
+                # Only evict full empty page. Do not evict partial empty page
+                align_evict_mask_to_page_size[len(batch.seq_lens),](
+                    batch.seq_lens,
+                    evict_mask,
+                    page_size,
+                    self.draft_token_num,
+                    next_power_of_2(self.draft_token_num),
+                )
+                token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+            else:
+                # Shift the accepted tokens to the beginning.
+                # Only evict the last part
+                src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
+                    batch.seq_lens,
+                    batch.out_cache_loc,
+                    accept_index,
+                    accept_length,
+                    self.draft_token_num,
+                    page_size,
+                )
+                to_free_slots = torch.empty(
+                    (to_free_num_slots.sum().item(),),
+                    dtype=torch.int64,
+                    device=to_free_num_slots.device,
+                )
+
+                # out_cache_loc:  [0  1  2,  3  4  5,  6  7  8]
+                # accept_index:   [0 -1  2,  3  4 -1,  6 -1 -1]
+                # tgt_cache_loc:  [0  1   ,  3  4   ,  6      ]
+                # to_free_slots:  [      2,        5,     7  8]
+                # to_free_slots also needs to be page-aligned without the first partial page
+                #
+                # split each row of out_cache_loc into two parts.
+                # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
+                # 2. the second part goes to to_free_slots.
+                get_target_cache_loc[(bs,)](
+                    tgt_cache_loc,
+                    to_free_slots,
+                    accept_length,
+                    to_free_num_slots,
+                    batch.out_cache_loc,
+                    self.draft_token_num,
+                    next_power_of_2(self.draft_token_num),
+                    next_power_of_2(bs),
+                )
 
-
+                # Free the kv cache
+                token_to_kv_pool_allocator.free(to_free_slots)
+
+                # Copy the kv cache
+                batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
+                    tgt_cache_loc, src_cache_loc
+                )
 
         # Construct EagleVerifyOutput
         if not has_finished:
-
-
-
-
-
-
-
-
-
-
+            if page_size == 1 or self.topk == 1:
+                batch.out_cache_loc = batch.out_cache_loc[accept_index]
+                assign_req_to_token_pool[(bs,)](
+                    batch.req_pool_indices,
+                    batch.req_to_token_pool.req_to_token,
+                    batch.seq_lens,
+                    batch.seq_lens + accept_length + 1,
+                    batch.out_cache_loc,
+                    batch.req_to_token_pool.req_to_token.shape[1],
+                    next_power_of_2(bs),
+                )
+            else:
+                batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
-            accept_length_cpu = accept_length.tolist()
 
             draft_input = EagleDraftInput()
             draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
             draft_input.verified_id = verified_id
             draft_input.accept_length = accept_length
-            draft_input.accept_length_cpu =
+            draft_input.accept_length_cpu = accept_length.tolist()
             draft_input.seq_lens_for_draft_extend = batch.seq_lens
             draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
 
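For the `page_size > 1`, `topk > 1` path added above, the comments describe splitting each request's row of `out_cache_loc`: the first `accept_length[i] + 1` slots are compacted into `tgt_cache_loc` and the remainder is handed back to the allocator. A plain-PyTorch sketch of that split (ignoring the page-alignment adjustment that `get_src_tgt_cache_loc` applies), using the example values from the comment:

```python
import torch

# Values from the comment above: each row of out_cache_loc holds the slots
# reserved for one request's draft tokens; accept_length counts accepted
# draft tokens per request (the +1 accounts for the bonus token).
out_cache_loc = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
accept_length = torch.tensor([1, 1, 0])

tgt_rows, free_rows = [], []
for row, acc in zip(out_cache_loc, accept_length.tolist()):
    keep = acc + 1
    tgt_rows.append(row[:keep])   # kept slots, compacted to the front
    free_rows.append(row[keep:])  # trailing slots handed back to the allocator

print(torch.cat(tgt_rows).tolist())   # [0, 1, 3, 4, 6]
print(torch.cat(free_rows).tolist())  # [2, 5, 7, 8]
```

The `get_target_cache_loc` Triton kernel added later in this diff does the same partitioning on the GPU, one program per request, using prefix sums of `accept_length` and `to_free_num_slots` to find each request's write offsets.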
@@ -469,47 +534,66 @@ class EagleVerifyInput:
                 draft_input=draft_input,
                 logits_output=logits_output,
                 verified_id=verified_id,
-                accept_length_per_req_cpu=accept_length_cpu,
+                accept_length_per_req_cpu=draft_input.accept_length_cpu,
                 accepted_indices=accept_index,
             )
         else:
-
-
-
-
-
-
-
-
-
-
-
+            if page_size == 1 or self.topk == 1:
+                assign_req_to_token_pool[(bs,)](
+                    batch.req_pool_indices,
+                    batch.req_to_token_pool.req_to_token,
+                    batch.seq_lens,
+                    batch.seq_lens + accept_length + 1,
+                    batch.out_cache_loc[accept_index],
+                    batch.req_to_token_pool.req_to_token.shape[1],
+                    next_power_of_2(bs),
+                )
+                batch.seq_lens.add_(accept_length + 1)
 
+            accept_length_cpu = accept_length.tolist()
             draft_input = EagleDraftInput()
-            if len(
-
-                unfinished_index_device = torch.tensor(
-
-
-
-                draft_input.verified_id = predict[new_accept_index]
-                draft_input.accept_length_cpu = [
+            if len(unfinished_accept_index) > 0:
+                unfinished_accept_index = torch.cat(unfinished_accept_index)
+                unfinished_index_device = torch.tensor(
+                    unfinished_index, dtype=torch.int64, device=predict.device
+                )
+                draft_input_accept_length_cpu = [
                     accept_length_cpu[i] for i in unfinished_index
                 ]
-
-
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                    unfinished_index_device
-                ]
-                draft_input.req_pool_indices_for_draft_extend = (
-                    batch.req_pool_indices[unfinished_index_device]
-                )
+                if page_size == 1 or self.topk == 1:
+                    batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
                 else:
-
-
-
+                    batch.out_cache_loc = torch.empty(
+                        len(unfinished_index) + sum(draft_input_accept_length_cpu),
+                        dtype=torch.int64,
+                        device=predict.device,
+                    )
+                    accept_length_filter = create_accept_length_filter(
+                        accept_length,
+                        unfinished_index_device,
+                        batch.seq_lens,
                     )
-
+                    filter_finished_cache_loc_kernel[(bs,)](
+                        batch.out_cache_loc,
+                        tgt_cache_loc,
+                        accept_length,
+                        accept_length_filter,
+                        next_power_of_2(bs),
+                        next_power_of_2(self.draft_token_num),
+                    )
+
+                draft_input.hidden_states = batch.spec_info.hidden_states[
+                    unfinished_accept_index
+                ]
+                draft_input.verified_id = predict[unfinished_accept_index]
+                draft_input.accept_length_cpu = draft_input_accept_length_cpu
+                draft_input.accept_length = accept_length[unfinished_index_device]
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                    unfinished_index_device
+                ]
+                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
+                    unfinished_index_device
+                ]
 
         return EagleVerifyOutput(
             draft_input=draft_input,
@@ -586,36 +670,75 @@ def assign_draft_cache_locs(
     req_pool_indices,
     req_to_token,
     seq_lens,
+    extend_lens,
+    num_new_pages_per_topk,
     out_cache_loc,
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
     page_size: tl.constexpr,
+    bs_upper: tl.constexpr,
+    iter_upper: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr =
+    BLOCK_SIZE: tl.constexpr = 128
     pid = tl.program_id(axis=0)
-    kv_start = tl.load(seq_lens + pid)
 
     if page_size == 1 or topk == 1:
-
+        copy_len = topk * speculative_num_steps
         out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
     else:
-
-
-
-
-        ) // page_size
-        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+        bs_offset = tl.arange(0, bs_upper)
+        copy_len = tl.load(extend_lens + pid)
+        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
+        out_cache_ptr = out_cache_loc + cum_copy_len
 
+    # Part 1: Copy from out_cache_loc to req_to_token
+    kv_start = tl.load(seq_lens + pid)
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
-    num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
+    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
     for i in range(num_loop):
-
-
-
-
-
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = copy_offset < copy_len
+        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
+        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
+
+    if page_size == 1 or topk == 1:
+        return
+
+    # Part 2: Copy the indices for the last partial page
+    prefix_len = tl.load(seq_lens + pid)
+    last_page_len = prefix_len % page_size
+    offsets = tl.arange(0, page_size)
+    mask = offsets < last_page_len
+    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
+    prefix_base = token_pool + prefix_len - last_page_len
+
+    for topk_id in range(topk):
+        value = tl.load(prefix_base + offsets, mask=mask)
+        tl.store(
+            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
+            value,
+            mask=mask,
+        )
+
+    # Part 3: Remove the padding in out_cache_loc
+    iter_offest = tl.arange(0, iter_upper)
+    for topk_id in range(topk):
+        indices = tl.load(
+            prefix_base
+            + topk_id * num_new_pages_per_topk_ * page_size
+            + last_page_len
+            + iter_offest,
+            mask=iter_offest < speculative_num_steps,
+        )
+        tl.store(
+            out_cache_loc
+            + pid * topk * speculative_num_steps
+            + topk_id * speculative_num_steps
+            + iter_offest,
+            indices,
+            mask=iter_offest < speculative_num_steps,
+        )
 
 
 @triton.jit
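In the rewritten `assign_draft_cache_locs`, each program locates the start of its slice of the flattened `out_cache_loc` by summing `extend_lens` over all earlier requests. A host-side equivalent of that offset computation, shown here only to illustrate the exclusive-prefix-sum pattern (toy lengths):

```python
import torch

# cum_copy_len for request i is sum(extend_lens[:i]) -- an exclusive prefix
# sum -- which is where request i's entries start in the flattened buffer.
extend_lens = torch.tensor([5, 3, 7])
start_offsets = torch.cumsum(extend_lens, dim=0) - extend_lens
print(start_offsets.tolist())  # [0, 5, 8]
```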
@@ -626,20 +749,23 @@ def generate_draft_decode_kv_indices(
     kv_indices,
     kv_indptr,
     positions,
-    num_seqs: tl.constexpr,
-    topk: tl.constexpr,
     pool_len: tl.constexpr,
     kv_indices_stride: tl.constexpr,
     kv_indptr_stride: tl.constexpr,
     bs_upper: tl.constexpr,
     iter_upper: tl.constexpr,
     num_tokens_upper: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 128
     iters = tl.program_id(axis=0)
     bid = tl.program_id(axis=1)
     topk_id = tl.program_id(axis=2)
 
+    num_steps = tl.num_programs(axis=0)
+    num_seqs = tl.num_programs(axis=1)
+    topk = tl.num_programs(axis=2)
+
     kv_indices += kv_indices_stride * iters
     kv_indptr += kv_indptr_stride * iters
     iters += 1
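The updated `generate_draft_decode_kv_indices` drops the `num_seqs` and `topk` constexpr arguments and instead reads the launch-grid dimensions from inside the kernel with `tl.num_programs`. A minimal, standalone Triton sketch of that pattern (hypothetical kernel name; requires a CUDA device):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def grid_dims_kernel(out_ptr):
    # Recover the launch-grid shape from inside the kernel instead of passing
    # it in as constexpr arguments; every program writes the same three values.
    tl.store(out_ptr + 0, tl.num_programs(axis=0))
    tl.store(out_ptr + 1, tl.num_programs(axis=1))
    tl.store(out_ptr + 2, tl.num_programs(axis=2))


out = torch.zeros(3, dtype=torch.int32, device="cuda")
grid_dims_kernel[(4, 2, 8)](out)
print(out.tolist())  # [4, 2, 8]
```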
@@ -649,6 +775,7 @@ def generate_draft_decode_kv_indices(
     seq_len = tl.load(paged_kernel_lens + bid)
     cum_seq_len = tl.sum(seq_lens)
 
+    # Update kv_indices
     kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
     kv_ptr = kv_indices + kv_offset
     token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
@@ -662,10 +789,26 @@ def generate_draft_decode_kv_indices(
         kv_offset += BLOCK_SIZE
 
     extend_offset = tl.arange(0, iter_upper)
-
-
-
-
+    if page_size == 1 or topk == 1:
+        extend_data = tl.load(
+            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
+            mask=extend_offset < iters,
+        )
+    else:
+        prefix_len = seq_len
+        last_page_len = prefix_len % page_size
+        num_new_pages_per_topk = (
+            last_page_len + num_steps + page_size - 1
+        ) // page_size
+        prefix_base = seq_len // page_size * page_size
+        start = (
+            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
+        )
+        extend_data = tl.load(
+            token_pool_ptr + start + extend_offset,
+            mask=extend_offset < iters,
+        )
+
     tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
 
     # Update kv_indptr
@@ -704,6 +847,116 @@ def align_evict_mask_to_page_size(
             tl.store(evict_mask + bid * num_draft_tokens + i, False)
 
 
+@triton.jit
+def get_target_cache_loc(
+    tgt_cache_loc,
+    to_free_slots,
+    accept_length,
+    to_free_num_slots,
+    out_cache_loc,
+    num_verify_tokens: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+    bs_upper: tl.constexpr,
+):
+    bid = tl.program_id(axis=0)
+    offset = tl.arange(0, num_verify_tokens_upper)
+    bs_offset = tl.arange(0, bs_upper)
+
+    # write the first part to tgt_cache_loc
+    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
+    copy_len = tl.load(accept_length + bid) + 1
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
+    )
+    tl.store(
+        tgt_cache_loc + tgt_cache_loc_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+    # write the second part to to_free_num_pages
+    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
+    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
+    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
+    to_free_slots_start = tl.sum(to_free_num_slots_all)
+
+    copy_len = to_free_num_slots_cur
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
+        mask=offset < copy_len,
+    )
+    tl.store(
+        to_free_slots + to_free_slots_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+
+@torch.compile(dynamic=True)
+def get_src_tgt_cache_loc(
+    seq_lens: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_length: torch.Tensor,
+    draft_token_num: int,
+    page_size: int,
+):
+    src_cache_loc = out_cache_loc[accept_index]
+    tgt_cache_loc = torch.empty_like(src_cache_loc)
+    extended_len = seq_lens + draft_token_num
+    keep_len = torch.minimum(
+        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
+        extended_len,
+    )
+    to_free_num_slots = extended_len - keep_len
+    return src_cache_loc, tgt_cache_loc, to_free_num_slots
+
+
+@triton.jit
+def filter_finished_cache_loc_kernel(
+    out_cache_loc,
+    tgt_cache_loc,
+    accept_length,
+    accept_length_filter,
+    bs_upper: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+):
+    bid = tl.program_id(0)
+    bs_offset = tl.arange(0, bs_upper)
+
+    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    old_start = tl.sum(accept_length_all) + bid
+
+    accept_length_filter_all = tl.load(
+        accept_length_filter + bs_offset, mask=bs_offset < bid
+    )
+    new_start = tl.sum(accept_length_filter_all)
+
+    copy_len = tl.load(accept_length_filter + bid)
+    copy_offset = tl.arange(0, num_verify_tokens_upper)
+    value = tl.load(
+        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
+    )
+    tl.store(
+        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
+    )
+
+
+@torch.compile(dynamic=True)
+def create_accept_length_filter(
+    accept_length: torch.Tensor,
+    unfinished_index_device: torch.Tensor,
+    seq_lens: torch.Tensor,
+):
+    accept_length_filter = torch.zeros_like(accept_length)
+    accept_length_filter[unfinished_index_device] = (
+        accept_length[unfinished_index_device] + 1
+    )
+    seq_lens.add_(accept_length + 1)
+    return accept_length_filter
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
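`get_src_tgt_cache_loc` above frees only whole trailing pages: it rounds the post-verification length up to the next page boundary, caps it at the currently reserved length, and frees whatever lies beyond. A small worked example of that arithmetic with made-up numbers:

```python
import torch

# Hypothetical numbers, chosen only to exercise the rounding.
page_size = 4
draft_token_num = 6
seq_lens = torch.tensor([10, 13])      # prefix lengths before verification
accept_length = torch.tensor([2, 0])   # accepted draft tokens per request

extended_len = seq_lens + draft_token_num          # slots currently reserved: [16, 19]
keep_len = torch.minimum(
    (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
    extended_len,
)                                                  # rounded up to a page: [16, 16]
to_free_num_slots = extended_len - keep_len
print(to_free_num_slots.tolist())                  # [0, 3]
```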
@@ -762,15 +1015,35 @@ def _generate_simulated_accept_index(
     spec_steps,
 ):
     simulate_acc_len_float = float(simulate_acc_len)
-
-
-
-
-
-
-
-
-
+    if SIMULATE_ACC_METHOD == "multinomial":
+        simulated_values = torch.normal(
+            mean=simulate_acc_len_float,
+            std=1.0,
+            size=(1,),
+            device="cpu",
+        )
+        # clamp simulated values to be between 1 and self.spec_steps
+        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
+        simulate_acc_len = int(simulated_values.round().item())
+    elif SIMULATE_ACC_METHOD == "match-expected":
+        # multinomial sampling does not match the expected length
+        # we keep it for the sake of compatibility of existing tests
+        # but it's better to use "match-expected" for the cases that need to
+        # match the expected length, One caveat is that this will only sample
+        # either round down or round up of the expected length
+        simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
+        lower = int(simulate_acc_len_float // 1)
+        upper = lower + 1 if lower < spec_steps + 1 else lower
+        if lower == upper:
+            simulate_acc_len = lower
+        else:
+            weight_upper = simulate_acc_len_float - lower
+            weight_lower = 1.0 - weight_upper
+            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
+            sampled_index = torch.multinomial(probs, num_samples=1)
+            simulate_acc_len = lower if sampled_index == 0 else upper
+    else:
+        raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
 
     accept_indx_first_col = accept_index[:, 0].view(-1, 1)
     sim_accept_index = torch.full(
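The `match-expected` branch added above only ever samples the floor or the ceiling of the requested accept length, with the mixing weights chosen so the mean comes out right. A quick empirical check of that property (standalone sketch, not the sglang code path):

```python
import torch

# The floor/ceil mixture used above should average out to the target length.
target = 2.3
lower, upper = int(target), int(target) + 1
weight_upper = target - lower        # 0.3
probs = torch.tensor([1.0 - weight_upper, weight_upper])

samples = torch.multinomial(probs, num_samples=100_000, replacement=True)
acc_len = (lower + samples * (upper - lower)).float()
print(acc_len.mean().item())  # ~= 2.3
```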
@@ -861,9 +1134,9 @@ def generate_token_bitmask(
     """
     Generate the logit mask for structured output.
     Draft model's token can be either valid or invalid with respect to the grammar.
-    We need to perform DFS to
-    1. which tokens are accepted by the grammar
-    2. what is the corresponding logit mask.
+    We need to perform DFS to
+    1. figure out which tokens are accepted by the grammar.
+    2. if so, what is the corresponding logit mask.
     """
 
     num_draft_tokens = draft_tokens_cpu.shape[-1]
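`traverse_tree` itself is not shown in this diff, but the arrays passed to it (`retrieve_next_token`, `retrieve_next_sibling`) suggest the draft-token tree is stored in first-child/next-sibling form. A generic, hedged sketch of a DFS over that kind of encoding, assuming `-1` means "no link":

```python
def dfs(node, next_token, next_sibling, visit):
    # Depth-first walk of a tree stored as first-child (next_token) and
    # next-sibling arrays; -1 marks the absence of a link.
    while node != -1:
        visit(node)
        child = next_token[node]
        if child != -1:
            dfs(child, next_token, next_sibling, visit)
        node = next_sibling[node]


# Tree: 0 -> {1, 2}, 1 -> {3}
next_token = [1, 3, -1, -1]     # first child of each node
next_sibling = [-1, 2, -1, -1]  # next sibling of each node
order = []
dfs(0, next_token, next_sibling, order.append)
print(order)  # [0, 1, 3, 2]
```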
@@ -880,6 +1153,7 @@ def generate_token_bitmask(
             device="cpu",
         )
         grammar = req.grammar
+        s = time.perf_counter()
         traverse_tree(
             retrieve_next_token_cpu[i],
             retrieve_next_sibling_cpu[i],
@@ -889,6 +1163,12 @@ def generate_token_bitmask(
                 i * num_draft_tokens : (i + 1) * num_draft_tokens
             ],
         )
+        tree_traverse_time = time.perf_counter() - s
+        if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
+            logger.warning(
+                f"Bit mask generation took {tree_traverse_time} seconds with "
+                f"grammar: {req.grammar}"
+            )
 
     verify_input.grammar = grammar
     return allocate_token_bitmask