sglang-0.4.3.post2-py3-none-any.whl → sglang-0.4.3.post3-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
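The remainder of this page shows the per-file diff for sglang/srt/speculative/eagle_utils.py. For any of the other files listed above, the same kind of comparison can be reproduced locally. The sketch below is one way to do it; it assumes both wheel files have already been downloaded into the working directory (for example with `pip download sglang==0.4.3.post2 --no-deps` and likewise for 0.4.3.post3), and relies only on the fact that wheels are plain zip archives.

```python
import difflib
import zipfile

OLD = "sglang-0.4.3.post2-py3-none-any.whl"
NEW = "sglang-0.4.3.post3-py3-none-any.whl"


def member_lines(whl: zipfile.ZipFile, name: str) -> list:
    # Wheels are ordinary zip archives; decode members leniently for diffing.
    if name not in whl.namelist():
        return []
    return whl.read(name).decode("utf-8", errors="replace").splitlines()


with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    names = sorted(set(old_whl.namelist()) | set(new_whl.namelist()))
    for name in names:
        diff = difflib.unified_diff(
            member_lines(old_whl, name),
            member_lines(new_whl, name),
            fromfile=f"old/{name}",
            tofile=f"new/{name}",
            lineterm="",
        )
        for line in diff:
            print(line)
```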
sglang/srt/speculative/eagle_utils.py

@@ -1,16 +1,17 @@
 from __future__ import annotations
 
-import dataclasses
-from typing import TYPE_CHECKING, List
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List
 
 import torch
 import torch.nn.functional as F
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import (
     build_tree_kernel,
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
 
-@dataclasses.dataclass
+@dataclass
 class EagleDraftInput:
     # The inputs for decode
     # shape: (b, topk)
@@ -46,57 +47,46 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
+    # indices of unfinished requests during extend-after-decode
+    # e.g. [0, 2, 3, 4] if the request at index 1 has finished
+    keep_indices: List[int] = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
-        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
-        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-        batch.out_cache_loc = out_cache_loc
+        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
+        # Prefill only generates 1 token.
+        assert len(self.verified_id) == len(batch.seq_lens)
 
         pt = 0
-        for i, req in enumerate(batch.reqs):
-            req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            assert seq_len - pre_len == req.extend_input_len
-
-            if pre_len > 0:
-                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
-
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+        for i, extend_len in enumerate(batch.extend_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.concat(
+                (input_ids[1:], self.verified_id[i].reshape(1))
             )
 
-            pt += req.extend_input_len
-
-        # TODO: support batching inputs
-        assert len(batch.extend_lens) == 1
-        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
-
     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
-        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
+        assert len(batch.req_pool_indices) == len(batch.reqs)
 
         pt = 0
         i = 0
-        for req in batch.reqs:
+        self.keep_indices = []
+        for idx, req in enumerate(batch.reqs):
             if req.finished():
                 continue
+            self.keep_indices.append(idx)
             # assert seq_len - pre_len == req.extend_input_len
             input_len = batch.extend_lens[i]
             seq_len = seq_lens_cpu[i]
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                seq_len - input_len : seq_len
-            ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
-        assert pt == batch.out_cache_loc.shape[0]
 
-        self.positions = torch.empty_like(self.verified_id)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
         self.accept_length.add_(1)
 
         create_extend_spec_info[(self.accept_length.numel(),)](
@@ -117,14 +107,22 @@ class EagleDraftInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
+        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
+        req_pool_indices = req_pool_indices[keep_indices]
+        assert req_pool_indices.shape[0] == bs
+        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
+
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
 
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
 
         create_flashinfer_kv_indices_triton[(bs,)](
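The `keep_indices` field introduced above records which requests in a batch are still unfinished after a decode step, so that per-request tensors such as `req_pool_indices` can be filtered consistently in the method shown in the last hunk. A self-contained toy illustration of that bookkeeping (hypothetical values, CPU tensors so it runs anywhere):

```python
import torch

# Hypothetical batch of four requests in which the request at index 1 finished.
finished = [False, True, False, False]
keep_indices = [i for i, done in enumerate(finished) if not done]  # [0, 2, 3]

# Per-request state is then re-indexed with keep_indices so its rows stay
# aligned with the surviving (unfinished) requests.
req_pool_indices = torch.tensor([10, 11, 12, 13])
kept = req_pool_indices[torch.tensor(keep_indices)]
print(kept)  # tensor([10, 12, 13])
```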
@@ -162,7 +160,21 @@ class EagleDraftInput:
         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
 
 
-@dataclasses.dataclass
+@dataclass
+class EagleVerifyOutput:
+    # Draft input batch
+    draft_input: EagleDraftInput
+    # Logit outputs from target worker
+    logits_output: LogitsProcessorOutput
+    # Accepted token ids including the bonus token
+    verified_id: torch.Tensor
+    # Accepted token length per sequence in a batch, on CPU.
+    accept_length_per_req_cpu: List[int]
+    # Accepted indices from logits_output.next_token_logits
+    accepeted_indices_cpu: List[int]
+
+
+@dataclass
 class EagleVerifyInput:
     draft_token: torch.Tensor
     custom_mask: torch.Tensor
@@ -267,6 +279,7 @@ class EagleVerifyInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         batch_size = len(req_pool_indices)
@@ -285,7 +298,11 @@ class EagleVerifyInput:
         paged_kernel_lens = paged_kernel_lens + self.draft_token_num
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + self.draft_token_num * batch_size,
+            dtype=torch.int32,
+            device="cuda",
+        )
 
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
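Both call sites above gain a `paged_kernel_lens_sum` argument so the size of `kv_indices` can come from a host-side integer instead of being read back from the GPU; the TODO in the earlier hunk marks the one remaining `cum_kv_seq_len[-1]` use. A minimal sketch of why that matters, assuming a CUDA device is available:

```python
import torch

lens = torch.tensor([3, 5, 2], device="cuda")
cum = torch.zeros(4, dtype=torch.int32, device="cuda")
cum[1:] = torch.cumsum(lens, dim=0)

# Sizing a new tensor from cum[-1] turns a GPU scalar into a Python int,
# which blocks the host until the device catches up (a device sync).
kv_synced = torch.empty(int(cum[-1]), dtype=torch.int32, device="cuda")

# Passing the precomputed host-side total avoids the readback entirely.
lens_sum = 3 + 5 + 2
kv_no_sync = torch.empty(lens_sum, dtype=torch.int32, device="cuda")
```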
@@ -298,7 +315,21 @@ class EagleVerifyInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
 
-    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor):
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: torch.Tensor,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+    ) -> torch.Tensor:
+        """WARNING: This API modifies the states of logits_output in place.
+
+        Verify and find accepted tokens based on the logits output and the batch
+        (which contains spec decoding information).
+
+        This API updates values inside logits_output based on the accepted
+        tokens, i.e., logits_output.next_token_logits only contains
+        accepted token logits.
+        """
         draft_token = torch.cat(
             [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
             dim=-1,
@@ -367,7 +398,6 @@ class EagleVerifyInput:
 
         new_accept_index = []
         unfinished_index = []
-        finished_extend_len = {}  # {rid: accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
@@ -382,7 +412,6 @@ class EagleVerifyInput:
                 id = predict_cpu[idx]
                 # if not found_finished:
                 req.output_ids.append(id)
-                finished_extend_len[req.rid] = j + 1
                 req.check_finished()
                 if req.finished():
                     has_finished = True
@@ -400,11 +429,10 @@ class EagleVerifyInput:
         accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
         mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        batch.token_to_kv_pool.free(mem_need_free_idx)
+        token_to_kv_pool_allocator.free(mem_need_free_idx)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -427,20 +455,16 @@ class EagleVerifyInput:
             ]
             if has_finished:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                    unfinished_index
-                ]
             else:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens
-                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
-
-        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return (
-            draft_input,
-            logits_output,
-            verified_id,
-            finished_extend_len,
-            accept_length_cpu,
+            batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+        return EagleVerifyOutput(
+            draft_input=draft_input,
+            logits_output=logits_output,
+            verified_id=verified_id,
+            accept_length_per_req_cpu=accept_length_cpu,
+            accepeted_indices_cpu=accept_index,
         )
 
 
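`verify` now returns the `EagleVerifyOutput` dataclass from the earlier hunk instead of a positional five-tuple, and the dropped `finished_extend_len` bookkeeping disappears from the return value. Its callers are not shown on this page; the standalone sketch below simply mirrors the dataclass with plain Python types to show the named-field access that replaces tuple unpacking:

```python
from dataclasses import dataclass
from typing import List


# Standalone mirror of the new return type, with tensors replaced by plain
# ints so the example runs anywhere. Field names match the diff above.
@dataclass
class EagleVerifyOutput:
    draft_input: object
    logits_output: object
    verified_id: List[int]
    accept_length_per_req_cpu: List[int]
    accepeted_indices_cpu: List[int]


res = EagleVerifyOutput(
    draft_input=None,
    logits_output=None,
    verified_id=[17, 42, 99],
    accept_length_per_req_cpu=[2, 1],
    accepeted_indices_cpu=[0, 1, 3],
)

# Named access replaces positional unpacking of the old five-tuple, so adding
# or removing a field (like finished_extend_len) cannot silently misalign callers.
print(res.accept_length_per_req_cpu)  # [2, 1]
```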
@@ -456,6 +480,18 @@ def eagle_verify_retrive(
     draft_token_num: tl.constexpr,
     max_len_upper: tl.constexpr,
 ):
+    """
+    Args:
+        retrive_index: Pointer to indices of draft tokens
+        accept_mask: Mask indicating which tokens were accepted
+        retrive_cum_len: Cumulative lengths of token sequences in a batch
+        accept_index (out): Accepted token indices
+        accept_length (out): Length of accepted tokens per sequence in a batch
+        extract_index (out): Indices of the last accepted tokens
+        max_len: Maximum length in a batch
+        draft_token_num: Number of tokens speculatively generated
+        max_len_upper: An upper bound for token sequence length
+    """
     pid = tl.program_id(axis=0)
 
     retrive_end = tl.load(retrive_cum_len + pid + 1)
@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
     tl.store(kv_indptr + zid, base + zid * iters)
 
 
-@torch.compile
+@torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
@@ -671,13 +707,11 @@ def select_top_k_tokens(
             .unsqueeze(0)
             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
         )
-
     else:
         # The later decode steps
         expand_scores = torch.mul(
             scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
         )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
-
        topk_cs_p, topk_cs_index = fast_topk(
             expand_scores.flatten(start_dim=1), topk, dim=-1
         )  # (b, topk)
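The switch from bare `@torch.compile` to `@torch.compile(dynamic=True)` in the last hunks asks the compiler to treat tensor shapes as symbolic up front, so a function that sees varying batch sizes (as `select_top_k_tokens` does across speculative steps) can reuse one compiled graph instead of recompiling per shape. A minimal standalone illustration, not the sglang function itself:

```python
import torch


@torch.compile(dynamic=True)
def double(x: torch.Tensor) -> torch.Tensor:
    # With dynamic=True the first call traces with symbolic shapes, so the
    # differently sized calls below share one compiled artifact.
    return x * 2.0


for batch_size in (1, 4, 7):
    double(torch.randn(batch_size, 8))
```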