sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
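
The diff body below covers the two speculative-decoding modules, sglang/srt/speculative/eagle_utils.py and sglang/srt/speculative/eagle_worker.py. Several hunks now pass next_power_of_2(...) into Triton kernels as constexpr launch parameters, since tl.arange requires a power-of-two extent. A minimal sketch of such a helper for reference; the actual sglang.srt.utils.next_power_of_2 may be implemented differently:

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, via the bit-length trick (assumes n >= 1).
    return 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 3, 8, 9)] == [1, 4, 8, 16]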
sglang/srt/speculative/eagle_utils.py

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -10,11 +11,15 @@ import triton.language as tl

 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2

 if is_cuda_available():
     from sgl_kernel import (
@@ -34,6 +39,9 @@ import logging
 logger = logging.getLogger(__name__)


+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+
+
 @dataclass
 class EagleDraftInput:
     # The inputs for decode
@@ -93,7 +101,7 @@ class EagleDraftInput:
             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
             new_verified_id,
-
+            next_power_of_2(speculative_num_steps + 1),
         )

         batch.seq_lens_sum = sum(seq_lens_cpu)
@@ -225,18 +233,34 @@ class EagleVerifyInput:
             CaptureHiddenMode.FULL,
         )

-    def prepare_for_verify(self, batch: ScheduleBatch):
+    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
-
+
+        if page_size == 1:
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            end_offset = batch.seq_lens + self.draft_token_num
+        else:
+            prefix_lens = batch.seq_lens
+            end_offset = prefix_lens + self.draft_token_num
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+            )
+            self.last_loc = last_loc
+
         bs = batch.batch_size()
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-
+            end_offset,
             batch.out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
-
+            next_power_of_2(bs),
         )

     def generate_attn_arg_prefill(
@@ -282,6 +306,7 @@ class EagleVerifyInput:
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -305,6 +330,7 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

+        # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
             linear_penalty = torch.zeros(
@@ -317,6 +343,7 @@ class EagleVerifyInput:
                 torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )

+        # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
@@ -378,13 +405,24 @@ class EagleVerifyInput:
             deterministic=True,
         )

+        if SIMULATE_ACC_LEN:
+            # Do simulation
+            accept_index = _generate_simulated_accept_index(
+                accept_index=accept_index,
+                predict=predict,  # mutable
+                accept_length=accept_length,  # mutable
+                simulate_acc_len=SIMULATE_ACC_LEN,
+                bs=bs,
+                spec_steps=self.spec_steps,
+            )
+
         new_accept_index = []
         unfinished_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False

-        #
+        # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
             new_accept_index_ = []
@@ -407,13 +445,28 @@ class EagleVerifyInput:
                 unfinished_index.append(i)
             req.spec_verify_ct += 1

+        if has_finished:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        # Free the KV cache for unaccepted tokens
+        accept_index = accept_index[accept_index != -1]
+        verified_id = predict[accept_index]
+        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+        evict_mask[accept_index] = False
+
+        if page_size != 1:
+            align_evict_mask_to_page_size[len(batch.seq_lens),](
+                batch.seq_lens,
+                evict_mask,
+                page_size,
+                self.draft_token_num,
+                next_power_of_2(self.draft_token_num),
+            )
+
+        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+
+        # Construct EagleVerifyOutput
         if not has_finished:
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             batch.out_cache_loc = batch.out_cache_loc[accept_index]
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
@@ -422,7 +475,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc,
                 batch.req_to_token_pool.req_to_token.shape[1],
-
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -443,13 +496,6 @@ class EagleVerifyInput:
                 accepeted_indices=accept_index,
             )
         else:
-            accept_length = (accept_index != -1).sum(dim=1) - 1
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
                 batch.req_to_token_pool.req_to_token,
@@ -457,7 +503,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc[accept_index],
                 batch.req_to_token_pool.req_to_token.shape[1],
-
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -465,20 +511,21 @@ class EagleVerifyInput:
             draft_input = EagleDraftInput()
             if len(new_accept_index) > 0:
                 new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
                 draft_input.hidden_states = batch.spec_info.hidden_states[
                     new_accept_index
                 ]
                 draft_input.verified_id = predict[new_accept_index]
-                draft_input.accept_length = accept_length[unfinished_index]
                 draft_input.accept_length_cpu = [
                     accept_length_cpu[i] for i in unfinished_index
                 ]
+                draft_input.accept_length = accept_length[unfinished_index_device]
                 if has_finished:
                     draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-
+                        unfinished_index_device
                     ]
                     draft_input.req_pool_indices_for_draft_extend = (
-                        batch.req_pool_indices[
+                        batch.req_pool_indices[unfinished_index_device]
                     )
                 else:
                     draft_input.seq_lens_for_draft_extend = batch.seq_lens
@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 32
     pid = tl.program_id(axis=0)
     kv_start = tl.load(seq_lens + pid)
-
+
+    if page_size == 1 or topk == 1:
+        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+    else:
+        prefix_len = tl.load(seq_lens + pid)
+        last_page_len = prefix_len % page_size
+        num_new_page = (
+            last_page_len + speculative_num_steps + page_size - 1
+        ) // page_size
+        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

     num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
     for i in range(num_loop):
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
         tl.store(kv_indptr + zid, base + zid * iters)


+@triton.jit
+def align_evict_mask_to_page_size(
+    seq_lens,
+    evict_mask,
+    page_size: tl.constexpr,
+    num_draft_tokens: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    t_range = tl.arange(0, BLOCK_SIZE)
+
+    bid = tl.program_id(axis=0)
+    seq_len = tl.load(seq_lens + bid)
+    io_mask = t_range < num_draft_tokens
+    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+
+    num_trues = tl.sum(mask_row)
+    num_false = num_draft_tokens - num_trues
+
+    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
+    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
+        tl.store(evict_mask + bid * num_draft_tokens + i, False)
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def _generate_simulated_accept_index(
+    accept_index,
+    predict,
+    accept_length,
+    simulate_acc_len,
+    bs,
+    spec_steps,
+):
+    simulate_acc_len_float = float(simulate_acc_len)
+    simulated_values = torch.normal(
+        mean=simulate_acc_len_float,
+        std=1.0,
+        size=(1,),
+        device="cpu",
+    )
+    # clamp simulated values to be between 1 and self.spec_steps
+    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
+    simulate_acc_len = int(simulated_values.round().item())
+
+    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
+    sim_accept_index = torch.full(
+        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+    )
+    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
+        simulate_acc_len, device=accept_index.device
+    )
+    accept_length.fill_(simulate_acc_len - 1)
+    predict.fill_(100)  # some legit token id
+    return sim_accept_index
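
The page_size > 1 branch of assign_draft_cache_locs above rounds the draft KV range up to page boundaries and reserves page-aligned room per top-k branch. A standalone sketch of that arithmetic with a worked example (the helper name is illustrative, not part of the package):

def draft_kv_end(prefix_len: int, speculative_num_steps: int,
                 page_size: int, topk: int) -> int:
    # Tokens already sitting in the last, partially filled page.
    last_page_len = prefix_len % page_size
    # Pages needed to hold that partial page plus the new draft tokens.
    num_new_page = (last_page_len + speculative_num_steps + page_size - 1) // page_size
    # Full pages of the prefix, then num_new_page pages per top-k branch.
    return prefix_len // page_size * page_size + num_new_page * (page_size * topk)

# A prefix of 10 tokens with 3 draft steps, page_size=4, topk=1: the last
# page holds 2 tokens, so 2 + 3 = 5 tokens span 2 new pages, and
# kv_end = 8 + 2 * 4 = 16.
assert draft_kv_end(10, 3, 4, 1) == 16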
sglang/srt/speculative/eagle_worker.py

@@ -11,7 +11,11 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -67,6 +71,7 @@ class EAGLEWorker(TpModelWorker):
         self.gpu_id = gpu_id
         self.device = server_args.device
         self.target_worker = target_worker
+        self.page_size = server_args.page_size
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -145,15 +150,26 @@ class EAGLEWorker(TpModelWorker):
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
-
-
-
+            if not global_server_args_dict["use_mla_backend"]:
+                from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferMultiStepDraftBackend,
+                )

-
-
-
-
-
+                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+                    self.draft_model_runner,
+                    self.topk,
+                    self.speculative_num_steps,
+                )
+            else:
+                from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAMultiStepDraftBackend,
+                )
+
+                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+                    self.draft_model_runner,
+                    self.topk,
+                    self.speculative_num_steps,
+                )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
@@ -170,19 +186,19 @@ class EAGLEWorker(TpModelWorker):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "
-            from sglang.srt.layers.attention.
-
+        elif self.server_args.attention_backend == "fa3":
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionMultiStepBackend,
             )

-            self.draft_attn_backend =
+            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                 self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
-            self.has_prefill_wrapper_verify =
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
@@ -234,14 +250,11 @@ class EAGLEWorker(TpModelWorker):
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
-                spec_info
+                spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )

-            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
-            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
-
             # If it is None, it means all requests are finished
             if batch.spec_info.verified_id is not None:
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
@@ -305,9 +318,59 @@ class EAGLEWorker(TpModelWorker):
         )

         # Allocate cache locations
-
-
-
+        if self.page_size == 1:
+            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
+                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+            )
+        else:
+            if self.topk == 1:
+                prefix_lens = batch.seq_lens
+                seq_lens = prefix_lens + self.speculative_num_steps
+                extend_num_tokens = num_seqs * self.speculative_num_steps
+            else:
+                # In this case, the last partial page needs to be duplicated.
+                # KV cache layout in batch.req_to_token_pool.req_to_token:
+                #
+                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
+                #    prefix     top-k = 0    tok-k = 1    top-k = 2
+                #
+                # "-" means prefix tokens
+                # "x" means speculative draft tokens
+                # "." means padded tokens
+
+                # TODO: fuse these ops
+                prefix_lens = batch.seq_lens
+                last_page_lens = prefix_lens % self.page_size
+                num_new_pages = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens = (
+                    prefix_lens // self.page_size * self.page_size
+                    + num_new_pages * (self.page_size * self.topk)
+                )
+                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
+                raise NotImplementedError(
+                    "page_size > 1 and top_k > 1 are not supported."
+                )
+                # TODO: Support page_size > 1 and top_k > 1
+                # 1. Duplicate the KV cache in the last partial page for all top-k segments
+                # 2. Modify generate_draft_decode_kv_indices accordingly
+
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            out_cache_loc, token_to_kv_pool_state_backup = (
+                batch.alloc_paged_token_slots_extend(
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    extend_num_tokens,
+                    backup_state=True,
+                )
+            )
+
         assign_draft_cache_locs[(num_seqs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -316,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
+            self.page_size,
         )
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
@@ -343,6 +407,8 @@ class EAGLEWorker(TpModelWorker):
         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)

+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
         ret = EagleVerifyInput.create(
             spec_info.verified_id,
             score_list,
@@ -354,7 +420,7 @@ class EAGLEWorker(TpModelWorker):
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
         )
-        return ret
+        return ret

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -411,7 +477,7 @@ class EAGLEWorker(TpModelWorker):
         return score_list, token_list, parents_list

     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
-        spec_info.prepare_for_verify(batch)
+        spec_info.prepare_for_verify(batch, self.page_size)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
@@ -421,7 +487,10 @@ class EAGLEWorker(TpModelWorker):
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
-            batch,
+            batch,
+            logits_output,
+            self.token_to_kv_pool_allocator,
+            self.page_size,
         )

         # Post process based on verified outputs.
@@ -586,5 +655,5 @@ def load_token_map(token_map_path: str) -> List[int]:
         ignore_patterns=["*.bin", "*.safetensors"],
     )
     token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
-    hot_token_id = torch.load(token_map_path)
+    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int32)
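
In the eagle_worker.py hunks above, the draft path allocates KV slots with backup_state=True and calls token_to_kv_pool_allocator.restore_state(...) after the draft forward, so slots consumed by draft steps are handed back before verification allocates its own. A toy sketch of that backup/restore pattern, assuming a free-list style allocator (ToyTokenAllocator is illustrative; sglang's TokenToKVPoolAllocator differs in detail):

import torch

class ToyTokenAllocator:
    def __init__(self, size: int):
        self.free_slots = torch.arange(size, dtype=torch.int64)

    def alloc(self, n: int, backup_state: bool = False):
        # Snapshot the free list before handing out slots.
        backup = self.free_slots.clone() if backup_state else None
        out = self.free_slots[:n]
        self.free_slots = self.free_slots[n:]
        return (out, backup) if backup_state else out

    def restore_state(self, backup: torch.Tensor):
        # Roll back to the snapshot, returning the draft slots to the pool.
        self.free_slots = backup

allocator = ToyTokenAllocator(16)
loc, backup = allocator.alloc(6, backup_state=True)  # slots for draft tokens
# ... draft forward steps would run here ...
allocator.restore_state(backup)  # draft slots are free again for verify
assert len(allocator.free_slots) == 16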