sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
--- sglang/srt/layers/attention/flashinfer_backend.py (0.4.3.post1)
+++ sglang/srt/layers/attention/flashinfer_backend.py (0.4.3.post3)
@@ -19,16 +19,16 @@ import triton
 import triton.language as tl
 
 from sglang.global_config import global_config
-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -37,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper
+    from flashinfer.decode import PosEncodingMode
 
 
 class WrapperDispatch(Enum):
@@ -47,9 +47,7 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
-        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
 
 
 @dataclass
@@ -71,6 +69,7 @@ class FlashInferAttnBackend(AttentionBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -107,12 +106,6 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
-        self.enable_flashinfer_mla = False
-        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                self.enable_flashinfer_mla = True
-                global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -122,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
         self.workspace_buffer = global_workspace_buffer
+
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -130,24 +124,25 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
-            if self.enable_flashinfer_mla:
-                self.qo_indptr = [
-                    torch.zeros(
-                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-                    )
-                    for _ in range(self.num_wrappers)
-                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
 
-        self.kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device=model_runner.device
-        )
-        self.qo_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if kv_last_page_len_buf is None:
+            self.kv_last_page_len = torch.ones(
+                (max_bs,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            assert self.num_wrappers == 1
+            self.kv_last_page_len = kv_last_page_len_buf
+
+        if not self.skip_prefill:
+            self.qo_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
 
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
@@ -170,18 +165,14 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            if self.enable_flashinfer_mla:
-                self.decode_wrappers.append(
-                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
-                )
-            else:
-                self.decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                    )
+
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
                 )
+            )
 
         # Create indices updater
         if not skip_prefill:
@@ -291,37 +282,25 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                if self.enable_flashinfer_mla:
-                    decode_wrappers.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            use_cuda_graph=True,
-                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
-                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
-                            kv_indices=self.cuda_graph_kv_indices[i],
-                            kv_len_arr=self.kv_last_page_len[:num_tokens],
-                            backend="fa2",
-                        )
-                    )
-                else:
-                    decode_wrappers.append(
-                        BatchDecodeWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            use_cuda_graph=True,
-                            use_tensor_cores=self.decode_use_tensor_cores,
-                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                                :num_tokens
-                            ],
-                        )
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
                     )
+                )
+
 
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -373,7 +352,8 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -410,94 +390,64 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        if global_config.enable_flashinfer_mla:
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
 
-            logits_soft_cap = layer.logit_cap
+        logits_soft_cap = layer.logit_cap
+
+        if not self.forward_metadata.use_ragged:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
 
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+            o = prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                causal=not layer.is_cross_attention,
+                sm_scale=layer.scaling,
+                window_left=layer.sliding_window_size,
+                logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+        else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
 
-            o = o1
-
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                self._get_wrapper_idx(layer)
-            ]
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-
-            logits_soft_cap = layer.logit_cap
-
-            if not self.forward_metadata.use_ragged:
-                if k is not None:
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                        )
-
-                o = prefill_wrapper_paged.forward(
+            if self.forward_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=not layer.is_cross_attention,
+                    causal=False,
                     sm_scale=layer.scaling,
-                    window_left=layer.sliding_window_size,
-                    logits_soft_cap=logits_soft_cap,
-                    k_scale=layer.k_scale,
-                    v_scale=layer.v_scale,
+                    logits_soft_cap=layer.logit_cap,
                 )
-            else:
-                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
 
-                    o, _ = merge_state(o1, s1, o2, s2)
+                o, _ = merge_state(o1, s1, o2, s2)
 
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
         self,
@@ -517,45 +467,23 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
-        if self.enable_flashinfer_mla:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
-            o = decode_wrapper.run(
-                reshaped_q[:, :, : layer.v_head_dim],
-                reshaped_q[:, :, layer.v_head_dim :],
-                reshaped_k[:, :, : layer.v_head_dim],
-                reshaped_k[:, :, layer.v_head_dim :],
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
-            o = decode_wrapper.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
+        o = decode_wrapper.forward(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            sm_scale=layer.scaling,
+            logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
+        )
 
-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
@@ -603,11 +531,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -617,11 +543,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -641,7 +565,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -675,7 +599,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -700,15 +624,13 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper: Union[
-            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -727,40 +649,21 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
+            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
-
-        if global_config.enable_flashinfer_mla:
-            sm_scale = 1.0 / math.sqrt(192)
-            q_indptr = torch.arange(0, bs + 1).to(0).int()
-            kv_lens = paged_kernel_lens.to(torch.int32)
-            wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
-        else:
-            wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                data_type=self.data_type,
-                q_data_type=self.q_data_type,
-                non_blocking=True,
-            )
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+            non_blocking=True,
+        )
 
 
 class FlashInferIndicesUpdaterPrefill:
@@ -803,7 +706,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -817,7 +720,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -850,7 +753,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -891,7 +794,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -933,10 +836,11 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        bs = len(req_pool_indices)
+        bs = len(seq_lens)
         if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
@@ -959,52 +863,49 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )
 
         # extend part
         if use_ragged:
-            if global_config.enable_flashinfer_mla:
-                wrapper_ragged.begin_forward(
-                    qo_indptr=qo_indptr,
-                    kv_indptr=qo_indptr,
-                    num_qo_heads=self.num_qo_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_dim_qk=192,
-                    head_dim_vo=128,
-                    q_data_type=self.q_data_type,
-                )
-            else:
-                wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    q_data_type=self.q_data_type,
-                )
-
-        if not global_config.enable_flashinfer_mla:
-            # cached part
-            wrapper_paged.begin_forward(
+            wrapper_ragged.begin_forward(
+                qo_indptr,
                 qo_indptr,
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
-                1,
                 q_data_type=self.q_data_type,
-                custom_mask=custom_mask,
-                non_blocking=True,
             )
 
+        # cached part
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
+            non_blocking=True,
+        )
+
+
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
 
 class FlashInferMultiStepDraftBackend:
     """
@@ -1023,7 +924,8 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1032,6 +934,9 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
@@ -1039,9 +944,12 @@ class FlashInferMultiStepDraftBackend:
                     model_runner,
                     skip_prefill=True,
                     kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
+
         self.max_context_len = self.attn_backends[0].max_context_len
+
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
 
@@ -1071,13 +979,23 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )
 
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+        global global_override_indptr_cpu
+
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
+            global_override_indptr_cpu = indptr_cpu_whole[i]
             call_fn(i, forward_batch)
 
+        global_override_indptr_cpu = None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (
@@ -1089,6 +1007,8 @@ class FlashInferMultiStepDraftBackend:
         )
 
         def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -1105,6 +1025,7 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
+
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1138,48 +1059,12 @@ class FlashInferMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
 
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
-
-
 def should_use_tensor_core(
     kv_cache_dtype: torch.dtype,
     num_attention_heads: int,
@@ -1201,6 +1086,21 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"
 
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
@@ -1230,12 +1130,18 @@ def fast_decode_plan(
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
-    **kwargs,
+    non_blocking: bool = True,
 ) -> None:
-    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+    """
     batch_size = len(last_page_len)
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
+
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(
@@ -1248,13 +1154,19 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
+        # Skip these copies
+        # self._paged_kv_indptr_buf.copy_(indptr)
+        # self._paged_kv_indices_buf[: len(indices)] = indices
+        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
+
     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
         q_data_type = data_type
+
     if not hasattr(self, "empty_q_data"):
         self.empty_q_data = torch.empty(
             0,
@@ -1271,6 +1183,7 @@ def fast_decode_plan(
         ),
     )
     self.last_page_len = torch.ones(32768, dtype=torch.int32)
+
    empty_q_data = self.empty_q_data
    empty_kv_cache = self.empty_kv_cache
    stream = torch.cuda.current_stream()