sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,28 +7,26 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
-import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.global_config import global_config
-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -37,7 +35,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper
+    from flashinfer.decode import _get_range_buf, get_seq_lens
 
 
 class WrapperDispatch(Enum):
@@ -47,16 +45,12 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
-        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
 
 
 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[
-        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
 
@@ -73,11 +67,10 @@ class FlashInferAttnBackend(AttentionBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
-        self.is_multimodal = model_runner.model_config.is_multimodal
-
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -89,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
         )
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
+        self.is_multimodal = model_runner.model_config.is_multimodal
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -109,12 +103,6 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
-        self.enable_flashinfer_mla = False
-        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                self.enable_flashinfer_mla = True
-                global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -132,24 +120,25 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
-            if self.enable_flashinfer_mla:
-                self.qo_indptr = [
-                    torch.zeros(
-                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-                    )
-                    for _ in range(self.num_wrappers)
-                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
 
-        [7 lines not legible in this rendering]
+        if kv_last_page_len_buf is None:
+            self.kv_last_page_len = torch.ones(
+                (max_bs,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            assert self.num_wrappers == 1
+            self.kv_last_page_len = kv_last_page_len_buf
+
+        if not self.skip_prefill:
+            self.qo_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
 
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
@@ -162,60 +151,39 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
-                [5 lines not legible in this rendering]
-                    self.prefill_wrappers_paged.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                    self.prefill_wrappers_verify.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            backend="fa2",
-                        )
-                    )
-                else:
-                    self.prefill_wrappers_paged.append(
-                        BatchPrefillWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            backend="fa2",
-                        )
-                    )
-                    self.prefill_wrappers_verify.append(
-                        BatchPrefillWithPagedKVCacheWrapper(
-                            self.workspace_buffer, "NHD"
-                        )
-                    )
-            if self.enable_flashinfer_mla:
-                self.decode_wrappers.append(
-                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
-                )
-            else:
-                self.decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                    )
-                )
+                self.prefill_wrappers_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        backend="fa2",
+                    )
+                )
+                self.prefill_wrappers_verify.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                    )
+                )
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                )
+            )
 
         # Create indices updater
         if not skip_prefill:
             self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
                 model_runner, self
-            )
+            )  # for verify
         self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
-        self.prefill_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}  # For verify
+        self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -259,10 +227,7 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-        if self.is_multimodal or (
-            self.enable_flashinfer_mla
-            and not global_server_args_dict["disable_radix_cache"]
-        ):
+        if self.is_multimodal:
             use_ragged = False
             extend_no_prefix = False
         else:
@@ -316,37 +281,24 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                [11 lines not legible in this rendering]
-                    )
-                else:
-                    decode_wrappers.append(
-                        BatchDecodeWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            use_cuda_graph=True,
-                            use_tensor_cores=self.decode_use_tensor_cores,
-                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                                :num_tokens
-                            ],
-                        )
-                    )
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
+                    )
+                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -358,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
+            for i in range(self.num_wrappers):
+                decode_wrappers[i].begin_forward = partial(
+                    fast_decode_plan, decode_wrappers[i]
+                )
         elif forward_mode.is_target_verify():
             prefill_wrappers = []
             for i in range(self.num_wrappers):
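The hunk above rebinds `begin_forward` on each captured decode wrapper so that CUDA-graph capture goes through the trimmed `fast_decode_plan` path. A minimal, self-contained sketch of this `functools.partial` instance-override pattern; the `Planner` class and `fast_plan` function below are illustrative stand-ins, not sglang or flashinfer APIs:

from functools import partial

class Planner:
    def begin_forward(self, n: int) -> str:
        # Default, more expensive planning path.
        return f"full plan for {n} tokens"

def fast_plan(self: "Planner", n: int) -> str:
    # Drop-in replacement; it receives the instance explicitly through partial().
    return f"fast plan for {n} tokens"

p = Planner()
# Override only this instance; other Planner objects keep the default method.
p.begin_forward = partial(fast_plan, p)
print(p.begin_forward(8))  # -> "fast plan for 8 tokens"

Because the override is an instance attribute, only the wrappers that were captured into the CUDA graph take the fast path; wrappers created elsewhere keep the library's original planning method.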
@@ -398,7 +354,8 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -435,114 +392,64 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        [8 lines not legible in this rendering]
-
-            [1 line not legible in this rendering]
-                # use mla ragged prefill
-                o, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
-
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            else:
-                # use mla paged prefill
-                prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                    self._get_wrapper_idx(layer)
-                ]
-                if k is not None:
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v
-                        )
-                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-                k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-                o = prefill_wrapper_paged.run(
-                    qall[:, :, : layer.v_head_dim],
-                    qall[:, :, layer.v_head_dim :],
-                    k_buf[:, :, : layer.v_head_dim],
-                    k_buf[:, :, layer.v_head_dim :],
-                )
-
-            [1 line not legible in this rendering]
-        else:
-            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                self._get_wrapper_idx(layer)
-            ]
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-
-            logits_soft_cap = layer.logit_cap
-
-            if not self.forward_metadata.use_ragged:
-                if k is not None:
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                        )
-
-                o = prefill_wrapper_paged.forward(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=not layer.is_cross_attention,
-                    sm_scale=layer.scaling,
-                    window_left=layer.sliding_window_size,
-                    logits_soft_cap=logits_soft_cap,
-                    k_scale=layer.k_scale,
-                    v_scale=layer.v_scale,
-                )
-            else:
-                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
-
-                if self.forward_metadata.extend_no_prefix:
-                    o = o1
-                else:
-                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                        causal=False,
-                        sm_scale=layer.scaling,
-                        logits_soft_cap=layer.logit_cap,
-                    )
-
-                    o, _ = merge_state(o1, s1, o2, s2)
-
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-
-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        logits_soft_cap = layer.logit_cap
+
+        if not self.forward_metadata.use_ragged:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            o = prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                causal=not layer.is_cross_attention,
+                sm_scale=layer.scaling,
+                window_left=layer.sliding_window_size,
+                logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+        else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+            if self.forward_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=False,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
         self,
@@ -562,45 +469,23 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
-        if self.enable_flashinfer_mla:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
-            o = decode_wrapper.run(
-                reshaped_q[:, :, : layer.v_head_dim],
-                reshaped_q[:, :, layer.v_head_dim :],
-                reshaped_k[:, :, : layer.v_head_dim],
-                reshaped_k[:, :, layer.v_head_dim :],
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = decode_wrapper.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        o = decode_wrapper.forward(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            sm_scale=layer.scaling,
+            logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
@@ -648,11 +533,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -662,11 +545,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -686,7 +567,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -720,7 +601,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -745,23 +626,27 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper: Union[
-            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-            )
+
+            if wrapper.is_cuda_graph_enabled:
+                # Directly write to the cuda graph input buffer
+                kv_indices = wrapper._paged_kv_indices_buf
+            else:
+                kv_indices = torch.empty(
+                    paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+                )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -775,37 +660,18 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-            [12 lines not legible in this rendering]
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
-        else:
-            wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                data_type=self.data_type,
-                q_data_type=self.q_data_type,
-                non_blocking=True,
-            )
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+            non_blocking=True,
+        )
 
 
 class FlashInferIndicesUpdaterPrefill:
@@ -841,32 +707,28 @@ class FlashInferIndicesUpdaterPrefill:
 
     def update(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
 
     def update_single_wrapper(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[
-            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -899,7 +761,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -940,7 +802,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -972,9 +834,7 @@ class FlashInferIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -984,10 +844,11 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        bs = len(req_pool_indices)
+        bs = len(seq_lens)
         if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
@@ -1005,77 +866,54 @@ class FlashInferIndicesUpdaterPrefill:
                 kv_indices,
                 self.req_to_token.shape[1],
             )
-            [1 line not legible in this rendering]
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )
 
         # extend part
         if use_ragged:
-            [2 lines not legible in this rendering]
-                qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
-                num_qo_heads=self.num_qo_heads,
-                num_kv_heads=self.num_kv_heads,
-                head_dim_qk=192,
-                head_dim_vo=128,
-                q_data_type=self.q_data_type,
-            )
-            else:
-                wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    q_data_type=self.q_data_type,
-                )
-
-        if not global_config.enable_flashinfer_mla:
-            # cached part
-            wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                q_data_type=self.q_data_type,
-                custom_mask=custom_mask,
-                non_blocking=True,
-            )
-        elif (
-            global_config.enable_flashinfer_mla
-            and not global_server_args_dict["disable_radix_cache"]
-        ):
-            # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
-            wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                True,
-                1 / math.sqrt(192),
-                self.data_type,
-                self.data_type,
-            )
+            wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                q_data_type=self.q_data_type,
+            )
+
+        # cached part
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            q_data_type=self.q_data_type,
+            kv_data_type=self.data_type,
+            custom_mask=custom_mask,
+            non_blocking=True,
+        )
+
+
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global global_override_indptr_cpu
 
 
 class FlashInferMultiStepDraftBackend:
     """
@@ -1094,7 +932,8 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        [1 line not legible in this rendering]
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1103,6 +942,9 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
@@ -1110,14 +952,20 @@ class FlashInferMultiStepDraftBackend:
                     model_runner,
                     skip_prefill=True,
                     kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
+
         self.max_context_len = self.attn_backends[0].max_context_len
+
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
 
     def common_template(
-        self, [rest of line not legible in this rendering]
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: torch.Tensor,
+        call_fn: Callable,
     ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
@@ -1142,13 +990,23 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )
 
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+        global global_override_indptr_cpu
+
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
+            global_override_indptr_cpu = indptr_cpu_whole[i]
             call_fn(i, forward_batch)
 
+        global_override_indptr_cpu = None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (
@@ -1160,6 +1018,8 @@ class FlashInferMultiStepDraftBackend:
         )
 
         def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -1176,6 +1036,7 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
+
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1192,65 +1053,27 @@ class FlashInferMultiStepDraftBackend:
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
-            decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
-                forward_batch.batch_size
-            ][0]
-            decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
-    def init_forward_metadata_replay_cuda_graph(self, forward_batch: ForwardBatch):
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                [1 line not legible in this rendering]
+                bs,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
 
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
-
-
 def should_use_tensor_core(
     kv_cache_dtype: torch.dtype,
     num_attention_heads: int,
@@ -1272,6 +1095,21 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"
 
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
@@ -1284,6 +1122,11 @@ def should_use_tensor_core(
         return False
 
 
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
+
 def fast_decode_plan(
     self,
     indptr: torch.Tensor,
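The module-level `global_override_indptr_cpu` added above lets the multi-step draft backend hand `fast_decode_plan` an indptr that was copied to the host once, instead of paying an `indptr.cpu()` device-to-host sync on every speculative step (see the `common_template` hunk earlier and the `fast_decode_plan` hunks below). A minimal sketch of the same pattern under assumed names (`plan_step` and `run_draft_steps` are illustrative, not sglang or flashinfer APIs):

import torch

# Module-level override checked by the planning function.
global_override_indptr_cpu = None

def plan_step(indptr: torch.Tensor) -> torch.Tensor:
    # Fall back to a per-call device-to-host copy only when no override is set.
    return (
        global_override_indptr_cpu
        if global_override_indptr_cpu is not None
        else indptr.cpu()
    )

def run_draft_steps(indptr_gpu: torch.Tensor, num_steps: int) -> None:
    global global_override_indptr_cpu
    # One bulk copy for all steps instead of num_steps separate .cpu() calls.
    indptr_cpu_whole = indptr_gpu.cpu()
    for i in range(num_steps):
        global_override_indptr_cpu = indptr_cpu_whole[i]
        plan_step(indptr_gpu[i])
    global_override_indptr_cpu = None  # Reset so later callers copy normally.

run_draft_steps(torch.zeros(3, 5, dtype=torch.int32), num_steps=3)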
@@ -1301,12 +1144,21 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
-    [1 line not legible in this rendering]
+    non_blocking: bool = True,
 ) -> None:
-    """
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+    """
     batch_size = len(last_page_len)
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
+
+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(
@@ -1319,13 +1171,20 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
+        # Skip these copies because we directly write to them during prepartion
+        # self._paged_kv_indptr_buf.copy_(indptr)
+        # self._paged_kv_indices_buf[: len(indices)] = indices
+        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
+        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
+
     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
         q_data_type = data_type
+
     if not hasattr(self, "empty_q_data"):
         self.empty_q_data = torch.empty(
             0,
@@ -1342,27 +1201,56 @@ def fast_decode_plan(
             ),
         )
         self.last_page_len = torch.ones(32768, dtype=torch.int32)
-    [5 lines not legible in this rendering]
-        self._int_workspace_buffer,
-        self._pin_memory_int_workspace_buffer,
-        indptr.to("cpu"),
-        batch_size,
-        num_qo_heads,
-        num_kv_heads,
-        page_size,
-        self.is_cuda_graph_enabled,
-        window_left,
-        logits_soft_cap,
-        head_dim,
-        head_dim,
-        empty_q_data,
-        empty_kv_cache,
-        stream.cuda_stream,
-    )
+
+    indptr_host = (
+        global_override_indptr_cpu
+        if global_override_indptr_cpu is not None
+        else indptr.cpu()
+    )
+
+    if self.use_tensor_cores:
+        kv_lens_arr_host = get_seq_lens(
+            indptr_host, self.last_page_len[:batch_size], page_size
+        )
+
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_host,
+            kv_lens_arr_host,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+            torch.cuda.current_stream().cuda_stream,
+        )
+    else:
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            indptr_host,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            window_left,
+            logits_soft_cap,
+            head_dim,
+            head_dim,
+            self.empty_q_data,
+            self.empty_kv_cache,
+            torch.cuda.current_stream().cuda_stream,
+        )
+
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
     self._logits_soft_cap = logits_soft_cap