sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py (+220 -378)

@@ -19,16 +19,16 @@ import triton
 import triton.language as tl
 
 from sglang.global_config import global_config
-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -37,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper
+    from flashinfer.decode import PosEncodingMode
 
 
 class WrapperDispatch(Enum):
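
Note: the `AttentionBackend` base class moved from `sglang/srt/layers/attention/__init__.py` to `sglang/srt/layers/attention/base_attn_backend.py` (see the rename in the file list above). A minimal, hypothetical compatibility shim for downstream code that must import it across both versions; the try/except below is illustrative and not part of sglang:

# Illustrative shim: resolve AttentionBackend on either side of the rename.
try:
    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend  # >= 0.4.3.post3
except ImportError:
    from sglang.srt.layers.attention import AttentionBackend  # <= 0.4.3.post2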
@@ -47,16 +47,12 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
-        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
 
 
 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[
-        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
 
@@ -73,6 +69,7 @@ class FlashInferAttnBackend(AttentionBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -109,12 +106,6 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
-        self.enable_flashinfer_mla = False
-        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                self.enable_flashinfer_mla = True
-                global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -124,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
         self.workspace_buffer = global_workspace_buffer
+
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -132,24 +124,25 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
-            if self.enable_flashinfer_mla:
-                self.qo_indptr = [
-                    torch.zeros(
-                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-                    )
-                    for _ in range(self.num_wrappers)
-                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
 
-
-
-
-
-
-
-
+        if kv_last_page_len_buf is None:
+            self.kv_last_page_len = torch.ones(
+                (max_bs,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            assert self.num_wrappers == 1
+            self.kv_last_page_len = kv_last_page_len_buf
+
+        if not self.skip_prefill:
+            self.qo_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
 
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
@@ -162,48 +155,24 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
-                if (
-                    self.enable_flashinfer_mla
-                    and not global_server_args_dict["disable_radix_cache"]
-                ):
-                    # use mla paged prefill
-                    self.prefill_wrappers_paged.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                    self.prefill_wrappers_verify.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                else:
-                    self.prefill_wrappers_paged.append(
-                        BatchPrefillWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            backend="fa2",
-                        )
-                    )
-                    self.prefill_wrappers_verify.append(
-                        BatchPrefillWithPagedKVCacheWrapper(
-                            self.workspace_buffer, "NHD"
-                        )
-                    )
-            if self.enable_flashinfer_mla:
-                self.decode_wrappers.append(
-                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
-                )
-            else:
-                self.decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
+                self.prefill_wrappers_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
                         self.workspace_buffer,
                         "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
+                        backend="fa2",
                     )
                 )
+                self.prefill_wrappers_verify.append(
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                )
+
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                )
+            )
 
         # Create indices updater
         if not skip_prefill:
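
The new `kv_last_page_len` buffer above is preallocated as all ones. This backend drives flashinfer with page size 1 (the literal `1` passed to `begin_forward` later in this diff), so every KV page holds exactly one token and each request's last page always has length 1. A small illustrative sketch, with `max_bs` standing in for `model_runner.req_to_token_pool.size`:

import torch

max_bs = 4  # stand-in value for the request pool size
# Page size 1 => the last page of every request is always full (length 1),
# so this buffer never needs to be recomputed per batch.
kv_last_page_len = torch.ones((max_bs,), dtype=torch.int32)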
@@ -259,10 +228,7 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-        if self.is_multimodal or (
-            self.enable_flashinfer_mla
-            and not global_server_args_dict["disable_radix_cache"]
-        ):
+        if self.is_multimodal:
             use_ragged = False
             extend_no_prefix = False
         else:
@@ -316,37 +282,25 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
            for i in range(self.num_wrappers):
-
-
-
-
-
-
-
-
-
-
-
-                    )
-                else:
-                    decode_wrappers.append(
-                        BatchDecodeWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            use_cuda_graph=True,
-                            use_tensor_cores=self.decode_use_tensor_cores,
-                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                                :num_tokens
-                            ],
-                        )
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
                     )
+                )
+
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -398,7 +352,8 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -435,114 +390,64 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-
-
-
-
-
-
-
-
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
 
-
-            # use mla ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
+        logits_soft_cap = layer.logit_cap
 
+        if not self.forward_metadata.use_ragged:
+            if k is not None:
+                assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
-        else:
-            # use mla paged prefill
-            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                self._get_wrapper_idx(layer)
-            ]
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v
-                    )
-            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-            o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
-                k_buf[:, :, : layer.v_head_dim],
-                k_buf[:, :, layer.v_head_dim :],
-            )
 
-
+            o = prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                causal=not layer.is_cross_attention,
+                sm_scale=layer.scaling,
+                window_left=layer.sliding_window_size,
+                logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
         else:
-
-
-
-
-
-
-
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
             )
 
-
-
-
-
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                        )
-
-            o = prefill_wrapper_paged.forward(
+            if self.forward_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=not layer.is_cross_attention,
+                    causal=False,
                     sm_scale=layer.scaling,
-                    window_left=layer.sliding_window_size,
-                    logits_soft_cap=logits_soft_cap,
-                    k_scale=layer.k_scale,
-                    v_scale=layer.v_scale,
                )
-        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
-            if self.forward_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
-                )
+                    logits_soft_cap=layer.logit_cap,
+                )
 
-
+                o, _ = merge_state(o1, s1, o2, s2)
 
-
-
-
-
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
-
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
         self,
@@ -562,45 +467,23 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
-        if self.enable_flashinfer_mla:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
-            o = decode_wrapper.run(
-                reshaped_q[:, :, : layer.v_head_dim],
-                reshaped_q[:, :, layer.v_head_dim :],
-                reshaped_k[:, :, : layer.v_head_dim],
-                reshaped_k[:, :, layer.v_head_dim :],
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
-            o = decode_wrapper.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
+        o = decode_wrapper.forward(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            sm_scale=layer.scaling,
+            logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
+        )
 
-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
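
In the ragged prefill path above, attention over the uncached extend tokens (`o1`, `s1`) and over the cached prefix (`o2`, `s2`) is computed separately and then combined with `flashinfer.cascade.merge_state`, which merges two partial attention outputs using their log-sum-exp (LSE) values. A pure-PyTorch sketch of that merge rule (an illustrative reference, not the flashinfer kernel):

import torch

def merge_state_reference(o1, s1, o2, s2):
    # o1, o2: [tokens, heads, head_dim] partial attention outputs
    # s1, s2: [tokens, heads] log-sum-exp of the corresponding attention scores
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max).unsqueeze(-1)  # renormalized softmax mass of part 1
    w2 = torch.exp(s2 - s_max).unsqueeze(-1)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)       # LSE-weighted average of the outputs
    s = s_max + torch.log((w1 + w2).squeeze(-1))
    return o, s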
@@ -648,11 +531,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -662,11 +543,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -686,7 +565,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -720,7 +599,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -745,15 +624,13 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper: Union[
-            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -772,40 +649,21 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
+            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
-
-
-
-
-
-
-
-
-
-
-
-
-                64,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
-        else:
-            wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                data_type=self.data_type,
-                q_data_type=self.q_data_type,
-                non_blocking=True,
-            )
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+            non_blocking=True,
+        )
 
 
 class FlashInferIndicesUpdaterPrefill:
@@ -845,12 +703,10 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -861,12 +717,10 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -899,7 +753,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -940,7 +794,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -972,9 +826,7 @@ class FlashInferIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
        req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -984,10 +836,11 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        bs = len(req_pool_indices)
+        bs = len(seq_lens)
         if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
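
The `kv_indptr` construction above is a standard CSR-style layout: the cumulative sum turns per-request KV lengths into offsets, so request `i` owns `kv_indices[kv_indptr[i] : kv_indptr[i + 1]]`. A small worked example:

import torch

paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # per-request KV lengths
bs = len(paged_kernel_lens)
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
# kv_indptr == [0, 3, 8, 10]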
@@ -1010,72 +863,49 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )
 
         # extend part
         if use_ragged:
-
-
-                    qo_indptr=qo_indptr,
-                    kv_indptr=qo_indptr,
-                    num_qo_heads=self.num_qo_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_dim_qk=192,
-                    head_dim_vo=128,
-                    q_data_type=self.q_data_type,
-                )
-            else:
-                wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    q_data_type=self.q_data_type,
-                )
-
-        if not global_config.enable_flashinfer_mla:
-            # cached part
-            wrapper_paged.begin_forward(
+            wrapper_ragged.begin_forward(
+                qo_indptr,
                 qo_indptr,
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
-                1,
                 q_data_type=self.q_data_type,
-                custom_mask=custom_mask,
-                non_blocking=True,
-            )
-        elif (
-            global_config.enable_flashinfer_mla
-            and not global_server_args_dict["disable_radix_cache"]
-        ):
-            # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
-            wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                True,
-                1 / math.sqrt(192),
-                self.data_type,
-                self.data_type,
             )
 
+        # cached part
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
+            non_blocking=True,
+        )
+
+
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global global_override_indptr_cpu
+
 
 class FlashInferMultiStepDraftBackend:
     """
@@ -1094,7 +924,8 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1103,6 +934,9 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
@@ -1110,9 +944,12 @@ class FlashInferMultiStepDraftBackend:
                     model_runner,
                     skip_prefill=True,
                     kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
+
         self.max_context_len = self.attn_backends[0].max_context_len
+
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
 
@@ -1142,13 +979,23 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )
 
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+        global global_override_indptr_cpu
+
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
+            global_override_indptr_cpu = indptr_cpu_whole[i]
             call_fn(i, forward_batch)
 
+        global_override_indptr_cpu = None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (
@@ -1160,6 +1007,8 @@ class FlashInferMultiStepDraftBackend:
         )
 
         def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -1176,6 +1025,7 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
+
         for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1209,48 +1059,12 @@ class FlashInferMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
 
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
-
-
 def should_use_tensor_core(
     kv_cache_dtype: torch.dtype,
     num_attention_heads: int,
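
The `create_flashinfer_kv_indices_triton` kernel removed above is not dropped from the package: per the file list it moves to `sglang/srt/layers/attention/utils.py` (+39 lines) and is re-imported at the top of this file. For reference, a pure-PyTorch equivalent of what the removed kernel computes (one program per request, gathering that request's token slots from `req_to_token` into the flat `kv_indices` array):

import torch

def create_kv_indices_reference(
    req_to_token, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx, kv_indices
):
    # Mirrors the Triton kernel above, without the 512-wide block tiling.
    for pid in range(len(req_pool_indices)):
        req = int(req_pool_indices[pid])
        start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        end = start + int(paged_kernel_lens[pid])
        out = int(kv_indptr[pid])
        kv_indices[out : out + (end - start)] = req_to_token[req, start:end]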
@@ -1272,6 +1086,21 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"
 
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
@@ -1301,12 +1130,18 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
-
+    non_blocking: bool = True,
 ) -> None:
-    """
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+    """
     batch_size = len(last_page_len)
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
+
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(
@@ -1319,13 +1154,19 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
+        # Skip these copies
+        # self._paged_kv_indptr_buf.copy_(indptr)
+        # self._paged_kv_indices_buf[: len(indices)] = indices
+        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
+
     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
         q_data_type = data_type
+
     if not hasattr(self, "empty_q_data"):
         self.empty_q_data = torch.empty(
             0,
@@ -1342,6 +1183,7 @@ def fast_decode_plan(
         ),
     )
     self.last_page_len = torch.ones(32768, dtype=torch.int32)
+
     empty_q_data = self.empty_q_data
     empty_kv_cache = self.empty_kv_cache
     stream = torch.cuda.current_stream()
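
`fast_decode_plan` above is a drop-in replacement for the wrapper's stock plan/begin_forward that skips the buffer copies commented out in the CUDA-graph branch. This diff does not show where it is installed; one plausible wiring (a hypothetical sketch, not sglang's actual code) is to bind it onto a wrapper instance:

import types

def install_fast_plan(decode_wrapper, fast_decode_plan):
    # Replace the wrapper's per-step planning with the copy-skipping version.
    # decode_wrapper: a flashinfer BatchDecodeWithPagedKVCacheWrapper instance.
    decode_wrapper.begin_forward = types.MethodType(fast_decode_plan, decode_wrapper)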