sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
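The penaltylib entries above show the module layout being flattened: `penalizers/frequency_penalty.py`, `penalizers/presence_penalty.py`, and `penalizers/min_new_tokens.py` move to top-level modules under `sglang/srt/sampling/penaltylib/`. A hedged sketch of what this means for downstream imports; only the module paths are confirmed by the file list, and the `BatchedFrequencyPenalizer` class name is an assumption inferred from the module name:

```python
# Version-tolerant import across the 0.4.3.post2 -> 0.4.3.post4 layout change.
try:
    # 0.4.3.post4: flattened layout
    from sglang.srt.sampling.penaltylib.frequency_penalty import (
        BatchedFrequencyPenalizer,
    )
except ImportError:
    # 0.4.3.post2: nested "penalizers" package
    from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
        BatchedFrequencyPenalizer,
    )
```

The diff body below covers sglang/srt/speculative/eagle_utils.py (+92 -57). Removed lines whose content did not survive extraction are shown as bare "-" lines.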
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-import dataclasses
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, List
 
 import torch
@@ -8,9 +8,10 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import (
     build_tree_kernel,
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
 
-@dataclasses.dataclass
+@dataclass
 class EagleDraftInput:
     # The inputs for decode
     # shape: (b, topk)
@@ -46,57 +47,47 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
+    # indices of unfinished requests during extend-after-decode
+    # e.g. [0, 2, 3, 4] if only the 1st request is finished
+    keep_indices: List[int] = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
-
-
-
+        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
+        # Prefill only generate 1 token.
+        assert len(self.verified_id) == len(batch.seq_lens)
 
         pt = 0
-        for i, req in enumerate(batch.reqs):
-
-
-
-
-            if pre_len > 0:
-                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
-
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+        for i, extend_len in enumerate(batch.extend_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.concat(
+                (input_ids[1:], self.verified_id[i].reshape(1))
             )
-
-            pt += req.extend_input_len
-
-        # TODO: support batching inputs
-        assert len(batch.extend_lens) == 1
-        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
+            pt += extend_len
 
     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-
+        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
-        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
+        assert len(batch.req_pool_indices) == len(batch.reqs)
 
         pt = 0
         i = 0
-
+        self.keep_indices = []
+        for idx, req in enumerate(batch.reqs):
             if req.finished():
                 continue
+            self.keep_indices.append(idx)
             # assert seq_len - pre_len == req.extend_input_len
             input_len = batch.extend_lens[i]
             seq_len = seq_lens_cpu[i]
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                seq_len - input_len : seq_len
-            ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
-        assert pt == batch.out_cache_loc.shape[0]
 
-        self.positions = torch.empty_like(self.verified_id)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
         self.accept_length.add_(1)
 
         create_extend_spec_info[(self.accept_length.numel(),)](
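The rewritten `prepare_for_extend` drops the old single-request restriction (`assert len(batch.extend_lens) == 1`) and applies the same shift-and-append per request over the flattened `input_ids` tensor: each request's chunk loses its first token and gains that request's verified (bonus) token. A minimal standalone sketch of that pattern with toy values (not the sglang API):

```python
import torch

# Two requests flattened into one tensor, lengths 3 and 2.
input_ids = torch.tensor([10, 11, 12, 20, 21])
extend_lens = [3, 2]
verified_id = torch.tensor([99, 88])  # one verified token per request

pt = 0
for i, extend_len in enumerate(extend_lens):
    chunk = input_ids[pt : pt + extend_len]
    # Drop the first token of the chunk, append the verified token in place.
    input_ids[pt : pt + extend_len] = torch.concat(
        (chunk[1:], verified_id[i].reshape(1))
    )
    pt += extend_len

print(input_ids)  # tensor([11, 12, 99, 21, 88])
```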
@@ -117,14 +108,22 @@ class EagleDraftInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
+        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
+        req_pool_indices = req_pool_indices[keep_indices]
+        assert req_pool_indices.shape[0] == bs
+        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
+
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
 
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
 
         create_flashinfer_kv_indices_triton[(bs,)](
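The `TODO` in this hunk is about a subtle performance detail: sizing `torch.empty` with `cum_kv_seq_len[-1]` passes a 0-dim CUDA tensor where a Python int is expected, which forces a blocking device-to-host copy. The new `paged_kernel_lens_sum` argument carries that total as a host-side int instead, as the `EagleVerifyInput` hunks below already do. An illustrative sketch with toy values, assuming a CUDA device is available:

```python
import torch

paged_kernel_lens = torch.tensor([5, 7, 3], device="cuda")
paged_kernel_lens_sum = 15  # tracked on the host, e.g. from CPU seq lens

cum_kv_seq_len = torch.zeros(4, dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# cum_kv_seq_len[-1] is a GPU scalar; converting it to a size blocks the
# host until the GPU catches up.
kv_a = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")

# A host-side int avoids the synchronization entirely.
kv_b = torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")

assert kv_a.numel() == kv_b.numel() == 15
```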
@@ -162,7 +161,21 @@ class EagleDraftInput:
         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
 
 
-@dataclasses.dataclass
+@dataclass
+class EagleVerifyOutput:
+    # Draft input batch
+    draft_input: EagleDraftInput
+    # Logit outputs from target worker
+    logits_output: LogitsProcessorOutput
+    # Accepeted token ids including the bonus token
+    verified_id: torch.Tensor
+    # Accepeted token length per sequence in a batch in CPU.
+    accept_length_per_req_cpu: List[int]
+    # Accepeted indices from logits_output.next_token_logits
+    accepeted_indices_cpu: List[int]
+
+
+@dataclass
 class EagleVerifyInput:
     draft_token: torch.Tensor
     custom_mask: torch.Tensor
@@ -267,6 +280,7 @@ class EagleVerifyInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         batch_size = len(req_pool_indices)
@@ -285,7 +299,11 @@ class EagleVerifyInput:
             paged_kernel_lens = paged_kernel_lens + self.draft_token_num
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + self.draft_token_num * batch_size,
+            dtype=torch.int32,
+            device="cuda",
+        )
 
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
@@ -298,7 +316,21 @@ class EagleVerifyInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
 
-    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor):
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: torch.Tensor,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+    ) -> torch.Tensor:
+        """WARNING: This API in-place modifies the states of logits_output
+
+        Verify and find accepted tokens based on logits output and batch
+        (which contains spec decoding information).
+
+        This API updates values inside logits_output based on the accepted
+        tokens. I.e., logits_output.next_token_logits only contains
+        accepeted token logits.
+        """
         draft_token = torch.cat(
             [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
             dim=-1,
@@ -367,7 +399,6 @@ class EagleVerifyInput:
 
         new_accept_index = []
         unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
@@ -382,7 +413,6 @@ class EagleVerifyInput:
                 id = predict_cpu[idx]
                 # if not found_finished:
                 req.output_ids.append(id)
-                finished_extend_len[req.rid] = j + 1
                 req.check_finished()
                 if req.finished():
                     has_finished = True
@@ -400,11 +430,10 @@ class EagleVerifyInput:
         accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
         mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        batch.token_to_kv_pool.free(mem_need_free_idx)
+        token_to_kv_pool_allocator.free(mem_need_free_idx)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -427,20 +456,16 @@ class EagleVerifyInput:
         ]
         if has_finished:
             draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                unfinished_index
-            ]
         else:
             draft_input.seq_lens_for_draft_extend = batch.seq_lens
-
-
-
-
-
-
-
-
-            accept_length_cpu,
+        batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+        return EagleVerifyOutput(
+            draft_input=draft_input,
+            logits_output=logits_output,
+            verified_id=verified_id,
+            accept_length_per_req_cpu=accept_length_cpu,
+            accepeted_indices_cpu=accept_index,
         )
 
 
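Taken together, the last few hunks change `verify()` from returning a positional tuple (which carried the now-removed `finished_extend_len`) to returning the new `EagleVerifyOutput` dataclass, and move KV-cache freeing onto an explicitly passed `TokenToKVPoolAllocator`. A hypothetical caller-side sketch; only the signature and field names come from the diff, the surrounding variable names are assumptions:

```python
# 0.4.3.post4 style: named fields instead of tuple unpacking.
res = verify_input.verify(batch, logits_output, token_to_kv_pool_allocator)

draft_input = res.draft_input
logits_output = res.logits_output           # modified in place by verify()
verified_id = res.verified_id               # accepted ids incl. the bonus token
accept_lens = res.accept_length_per_req_cpu # per-request accepted lengths
```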
@@ -456,6 +481,18 @@ def eagle_verify_retrive(
     draft_token_num: tl.constexpr,
     max_len_upper: tl.constexpr,
 ):
+    """
+    Args:
+        retrive_index: Pointer to indices of draft tokens
+        accept_mask: Mask indicating which tokens were accepted
+        retrive_cum_len: Cumulative lengths of token sequences in a batch
+        accept_index (out): Accept token indices
+        accept_length (out): Length of accepted tokens per sequence in a batch
+        extract_index (out): Index for last accepted tokens
+        max_len: Maximum length in a batch
+        draft_token_num: Number of tokens speculatively generated
+        max_len_upper: An upper bound for token sequence length
+    """
     pid = tl.program_id(axis=0)
 
     retrive_end = tl.load(retrive_cum_len + pid + 1)
@@ -649,7 +686,7 @@ def generate_draft_decode_kv_indices(
         tl.store(kv_indptr + zid, base + zid * iters)
 
 
-@torch.compile
+@torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
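`@torch.compile` specializes on input shapes by default, so a decode loop whose batch size varies can trigger repeated recompilation; `dynamic=True` asks the compiler to treat shapes symbolically from the start. An illustrative standalone sketch (not the sglang kernel):

```python
import torch

@torch.compile(dynamic=True)
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

# The changing leading dimension reuses one shape-generic graph instead of
# compiling a new specialization per batch size.
for bs in (1, 3, 8):
    scale(torch.randn(bs, 4))
```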
@@ -671,13 +708,11 @@ def select_top_k_tokens(
             .unsqueeze(0)
             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
         )
-
     else:
         # The later decode steps
         expand_scores = torch.mul(
             scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
         )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
-
         topk_cs_p, topk_cs_index = fast_topk(
             expand_scores.flatten(start_dim=1), topk, dim=-1
         )  # (b, topk)