sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
```diff
--- /dev/null
+++ b/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -0,0 +1,582 @@
+from __future__ import annotations
+
+"""
+Support attention backend for flashinfer MLA.
+The flashinfer_mla_disable_ragged flag controls whether to use ragged prefill wrapper and defaults to be false.
+When it's set to false, all wrappers are BatchMLAPaged wrapper.
+When it's set to true, the backend uses BatchRagged and BatchMLAPaged wrapper for prefilling,
+and uses BatchMLAPaged wrapper for decoding.
+More details can be found in https://docs.flashinfer.ai/api/mla.html
+"""
+
+from dataclasses import dataclass
+from functools import partial
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+if is_flashinfer_available():
+    from flashinfer import (
+        BatchMLAPagedAttentionWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+
+
+@dataclass
+class DecodeMetadata:
+    decode_wrapper: BatchMLAPagedAttentionWrapper
+
+
+@dataclass
+class PrefillMetadata:
+    prefill_wrapper: BatchMLAPagedAttentionWrapper
+    use_ragged: bool
+
+
+# Reuse this workspace buffer across all flashinfer wrappers
+global_workspace_buffer = None
+
+
+class FlashInferMLAAttnBackend(AttentionBackend):
+    """Flashinfer attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+    ):
+        super().__init__()
+
+        # Parse constants
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+
+        global_config.enable_flashinfer_mla = True
+
+        # Allocate buffers
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
+
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+
+        self.q_indptr_decode = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+        )
+
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
+        )
+
+        self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
+            self.workspace_buffer,
+            backend="auto",
+        )
+
+        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
+            self.workspace_buffer, backend="auto"
+        )
+
+        # Create indices updater
+        self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
+            model_runner, self
+        )
+        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
+            model_runner, self
+        )
+
+        # Other metadata
+        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+        self.decode_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode_or_idle():
+            self.indices_updater_decode.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                decode_wrapper=self.decode_wrapper,
+                init_metadata_replay=False,
+            )
+            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
+        else:
+            prefix_lens = forward_batch.extend_prefix_lens
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+            use_ragged = (
+                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                and extend_no_prefix
+            )
+
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens,
+                prefill_wrapper_paged=self.prefill_wrapper_paged,
+                use_ragged=use_ragged,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrapper_paged, use_ragged
+            )
+
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
+        self.cuda_graph_kv_indptr = self.kv_indptr.clone()
+        self.cuda_graph_kv_lens = torch.ones(
+            (max_bs,), dtype=torch.int32, device=self.device
+        )
+
+        # For fast decode plan in graph replaying
+        self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
+        self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
+        self.fast_decode_kwargs = {
+            "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
+            "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
+            "kv_indices": self.cuda_graph_kv_indices,
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            decode_wrapper = BatchMLAPagedAttentionWrapper(
+                self.workspace_buffer,
+                use_cuda_graph=True,
+                qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
+                kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
+                kv_indices=self.cuda_graph_kv_indices,
+                kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
+                backend="auto",
+            )
+
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_decode.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                decode_wrapper=decode_wrapper,
+                init_metadata_replay=False,
+            )
+            self.decode_cuda_graph_metadata[bs] = decode_wrapper
+            self.forward_metadata = DecodeMetadata(decode_wrapper)
+            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        if forward_mode.is_decode_or_idle():
+            kv_len_arr_cpu = seq_lens_cpu[:bs]
+            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
+                kv_len_arr_cpu, dim=0
+            )
+            self.fast_decode_kwargs.update(
+                {
+                    "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
+                    "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
+                    "kv_len_arr_cpu": kv_len_arr_cpu,
+                }
+            )
+
+            self.indices_updater_decode.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                decode_wrapper=self.decode_cuda_graph_metadata[bs],
+                init_metadata_replay=True,
+                **self.fast_decode_kwargs,
+            )
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_mode=}")
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+
+        cache_loc = forward_batch.out_cache_loc
+        logits_soft_cap = layer.logit_cap
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
+        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        # Save kv cache
+        if save_kv_cache and k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+        if self.forward_metadata.use_ragged:
+            # ragged prefill
+            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                qall,
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+        else:
+            # mla paged prefill
+            o = prefill_wrapper_paged.run(
+                qall[:, :, : layer.v_head_dim],
+                qall[:, :, layer.v_head_dim :],
+                k_buf[:, :, : layer.v_head_dim],
+                k_buf[:, :, layer.v_head_dim :],
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        decode_wrapper = self.forward_metadata.decode_wrapper
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+        o = decode_wrapper.run(
+            reshaped_q[:, :, : layer.v_head_dim],
+            reshaped_q[:, :, layer.v_head_dim :],
+            reshaped_k[:, :, : layer.v_head_dim],
+            reshaped_k[:, :, layer.v_head_dim :],
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+
+class FlashInferMLAIndicesUpdaterDecode:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Parse Constants
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.attn_backend = attn_backend
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.q_indptr = attn_backend.q_indptr_decode
+
+    def update(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrapper: BatchMLAPagedAttentionWrapper,
+        init_metadata_replay: bool = False,
+        **fast_decode_kwargs,
+    ):
+        decode_wrapper = decode_wrapper or self.decode_wrapper
+        self.call_begin_forward(
+            decode_wrapper,
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            self.q_indptr,
+            self.kv_indptr,
+            init_metadata_replay,
+            **fast_decode_kwargs,
+        )
+
+    def call_begin_forward(
+        self,
+        wrapper: BatchMLAPagedAttentionWrapper,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        q_indptr: torch.Tensor,
+        kv_indptr: torch.Tensor,
+        init_metadata_replay: bool = False,
+        **fast_decode_kwargs,
+    ):
+        bs = len(req_pool_indices)
+        q_indptr = q_indptr[: bs + 1]
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = (
+            torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+            if not init_metadata_replay
+            else fast_decode_kwargs["kv_indices"]
+        )
+
+        kv_lens = paged_kernel_lens.to(torch.int32)
+        sm_scale = self.scaling
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.shape[1],
+        )
+        if not init_metadata_replay:
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.plan(
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+
+
+class FlashInferMLAIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Parse Constants
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.attn_backend = attn_backend
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.qo_indptr = attn_backend.qo_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+
+    def update(
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
+        use_ragged: bool,
+    ):
+        if use_ragged:
+            paged_kernel_lens = prefix_lens
+            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
+        else:
+            paged_kernel_lens = seq_lens
+            paged_kernel_lens_sum = seq_lens_sum
+
+        self.call_begin_forward(
+            self.prefill_wrapper_ragged,
+            prefill_wrapper_paged,
+            req_pool_indices,
+            paged_kernel_lens,
+            paged_kernel_lens_sum,
+            seq_lens,
+            prefix_lens,
+            self.kv_indptr,
+            self.qo_indptr,
+            use_ragged,
+        )
+
+    def call_begin_forward(
+        self,
+        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+        wrapper_paged: BatchMLAPagedAttentionWrapper,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        seq_lens: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        kv_indptr: torch.Tensor,
+        qo_indptr: torch.Tensor,
+        use_ragged: bool,
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=req_pool_indices.device,
+        )
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.shape[1],
+        )
+
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+        qo_indptr = qo_indptr[: bs + 1]
+        sm_scale = self.scaling
+
+        if use_ragged:
+            # ragged prefill
+            wrapper_ragged.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+            )
+        else:
+            # mla paged prefill
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            wrapper_paged.plan(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                True,
+                sm_scale,
+                self.q_data_type,
+                self.data_type,
+            )
+
+
+def fast_mla_decode_plan(
+    self,
+    qo_indptr_cpu: torch.Tensor,
+    kv_indptr_cpu: torch.Tensor,
+    kv_indices: torch.Tensor,
+    kv_len_arr_cpu: torch.Tensor,
+    num_heads: int,
+    head_dim_ckv: int,
+    head_dim_kpe: int,
+    page_size: int,
+    causal: bool,
+    sm_scale: float,
+    q_data_type: torch.dtype,
+    kv_data_type: torch.dtype,
+) -> None:
+    """A faster version of BatchMLAPagedAttentionWrapper::plan,
+    for skipping the stream synchronization in original plan function during
+    cuda graph replaying.
+    """
+    self._causal = causal
+    self._page_size = page_size
+    self._sm_scale = sm_scale
+
+    with self.device as device:
+        stream = torch.cuda.current_stream(device).cuda_stream
+        self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+            stream,
+        )
```
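The new backend's prefill path chooses between the ragged and paged wrappers based on the `flashinfer_mla_disable_ragged` server arg and whether any request in the batch extends a cached prefix (see `init_forward_metadata` above). Below is a minimal sketch of that decision, paraphrased as a standalone function; the function itself is illustrative and not part of the package.

```python
from typing import Sequence


def choose_prefill_wrapper(
    disable_ragged: bool, extend_prefix_lens_cpu: Sequence[int]
) -> str:
    """Mirror the use_ragged decision in FlashInferMLAAttnBackend.init_forward_metadata."""
    # Ragged prefill is only usable when no request carries a cached prefix.
    extend_no_prefix = not any(extend_prefix_lens_cpu)
    use_ragged = (not disable_ragged) and extend_no_prefix
    return (
        "BatchPrefillWithRaggedKVCacheWrapper"
        if use_ragged
        else "BatchMLAPagedAttentionWrapper"
    )


# Default flag value and a fresh batch with no prefix reuse -> ragged prefill.
assert choose_prefill_wrapper(False, [0, 0, 0]) == "BatchPrefillWithRaggedKVCacheWrapper"
# Any cached prefix (or disabling ragged) falls back to the paged MLA wrapper.
assert choose_prefill_wrapper(False, [0, 5, 0]) == "BatchMLAPagedAttentionWrapper"
assert choose_prefill_wrapper(True, [0, 0, 0]) == "BatchMLAPagedAttentionWrapper"
```

Decoding always goes through `BatchMLAPagedAttentionWrapper`, regardless of the flag.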
```diff
--- a/sglang/srt/layers/attention/torch_native_backend.py
+++ b/sglang/srt/layers/attention/torch_native_backend.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
```
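The one-line import change above follows the `attention/__init__.py → attention/base_attn_backend.py` move listed earlier, and the same substitution appears in the other backends touched by this release. For out-of-tree code that subclasses the base class, the equivalent change would look like the sketch below (whether the old path still re-exports `AttentionBackend` is not visible in this diff, so the new path is the safe import):

```python
# 0.4.3.post2 import path:
# from sglang.srt.layers.attention import AttentionBackend

# 0.4.3.post3 import path:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend


class MyCustomBackend(AttentionBackend):  # hypothetical subclass, for illustration only
    pass
```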
```diff
--- a/sglang/srt/layers/attention/triton_backend.py
+++ b/sglang/srt/layers/attention/triton_backend.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import triton
 
-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
                 spec_info.generate_attn_arg_prefill(
                     forward_batch.req_pool_indices,
                     forward_batch.seq_lens,
+                    None,
                     self.req_to_token,
                 )
             )
@@ -232,7 +233,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         assert encoder_lens is None, "Not supported"
 
@@ -310,7 +311,8 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         if forward_mode.is_decode_or_idle():
@@ -474,7 +476,7 @@ class TritonMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -586,6 +588,7 @@ class TritonMultiStepDraftBackend:
             encoder_lens=None,
             forward_mode=ForwardMode.DECODE,
             spec_info=forward_batch.spec_info,
+            seq_lens_cpu=None,
         )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
```
```diff
--- a/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -635,6 +635,9 @@ def decode_attention_fwd(
     logit_cap=0.0,
 ):
     assert num_kv_splits == attn_logits.shape[2]
+    assert q.shape[0] <= kv_indptr.shape[0] - 1
+    assert q.shape[0] <= attn_logits.shape[0]
+
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
```