sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -10,23 +10,30 @@ from typing import TYPE_CHECKING, Optional

 import torch

-from sglang.srt.layers.attention.flashinfer_backend import
+from sglang.srt.layers.attention.flashinfer_backend import (
+    FlashInferAttnBackend,
+    FlashInferMultiStepDraftBackend,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
     import flashinfer

+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInfo

 # Constants
-DEFAULT_WORKSPACE_SIZE_MB =
+DEFAULT_WORKSPACE_SIZE_MB = (
+    512  # Memory workspace size in MB, todo(Yingyi): read from config
+)

 # Reuse this workspace buffer across all TRTLLM MHA wrappers
-
+global_zero_init_workspace_buffer = None


 @dataclass
@@ -53,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
-
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        speculative_step_id: int = 0,
     ):
-        super().__init__(
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )

         config = model_runner.model_config

@@ -73,18 +83,28 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
         # Allocate buffers
-        global
-        if
-
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
                 self.workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer =
+        self.workspace_buffer = global_zero_init_workspace_buffer

         # CUDA graph state
         self.decode_cuda_graph_metadata = {}

+        # Speculative decoding
+        # Only support topk <= 1 for now.
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_step_id = speculative_step_id
+        self.target_verify_metadata = {}
+
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+
         # Forward metadata
         self.forward_metadata: Optional[TRTLLMMHAMetadata] = None

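The hunk above replaces a per-instance allocation with a module-level workspace (`global_zero_init_workspace_buffer`) that is created once and then shared by every TRTLLM MHA backend in the process. A minimal standalone sketch of that lazy-singleton pattern, using illustrative names and sizes rather than sglang's actual ones:

from typing import Optional

import torch

# Allocated at most once per process; every backend instance reuses the same storage.
_shared_workspace: Optional[torch.Tensor] = None


def get_shared_workspace(size_bytes: int, device: str = "cpu") -> torch.Tensor:
    """Return the process-wide zero-initialized workspace, allocating it on first use."""
    global _shared_workspace
    if _shared_workspace is None:
        _shared_workspace = torch.zeros(size_bytes, dtype=torch.uint8, device=device)
    return _shared_workspace


# Two "backends" end up holding the same tensor, so the memory cost is paid once.
a = get_shared_workspace(1024)
b = get_shared_workspace(1024)
assert a.data_ptr() == b.data_ptr()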
@@ -95,11 +115,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MHA."""
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "page_table": torch.zeros(
                 max_bs,
-
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
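`max_num_pages` above is a ceiling division: the page table needs enough slots to cover the longest supported context even when the last page is only partially filled. A small worked example with illustrative numbers:

max_context_len = 8192
page_size = 64

# Ceiling division without floating point.
max_num_pages = (max_context_len + page_size - 1) // page_size
assert max_num_pages == 128

# A length that does not divide evenly still reserves a slot for its final page.
assert (8193 + page_size - 1) // page_size == 129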
@@ -108,6 +129,70 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             ),
         }

+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            self.draft_extend_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.zeros(
+                    max_bs + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -120,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
+        device = seq_lens.device

-
-
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][:bs, :]
+                self.decode_cuda_graph_metadata[bs] = metadata
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+                batch_size = len(seq_lens)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )

-
-
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    :bs, :
+                ]
+                self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            # Target Verify
+            # Here we only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )

-
-
-
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+
+            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
+
+            self.target_verify_metadata[bs] = metadata
+        elif forward_mode.is_draft_extend():
+            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+            num_tokens_per_bs = num_tokens // bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.max_seq_len_k = seq_lens.max().item()
+
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
+
+            self.draft_extend_metadata[bs] = metadata
         self.forward_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
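In the draft-decode branches above, per-sequence KV lengths are offset by `self.speculative_step_id + 1`: by draft step `i`, the metadata accounts for `i + 1` extra draft positions per sequence, so both `cache_seqlens_int32` and `max_seq_len_k` are shifted by that amount. A small sketch of the bookkeeping with illustrative values:

import torch

seq_lens = torch.tensor([7, 12, 3], dtype=torch.int32)  # base lengths before drafting
speculative_step_id = 1  # second draft step, so 2 draft positions are accounted for

cache_seqlens = seq_lens + (speculative_step_id + 1)
max_seq_len_k = int(seq_lens.max().item()) + (speculative_step_id + 1)

assert cache_seqlens.tolist() == [9, 14, 5]
assert max_seq_len_k == 14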
@@ -147,21 +321,91 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        device = seq_lens.device
         metadata = None
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+
+                metadata.cache_seqlens_int32.copy_(
+                    seq_lens + self.speculative_step_id + 1
+                )
+            else:
+                # Normal Decode
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                    None, :
+                ],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        elif forward_mode.is_target_verify():
+            # Here we only support topk = 1 for now.
+            metadata = self.target_verify_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )

-
-
-
-
-
-
-
-
-
-        self.
-
-
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+            ]
+            page_indices //= self.page_size
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        elif forward_mode.is_draft_extend():
+            metadata = self.draft_extend_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item()
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
         self.forward_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self) -> int:
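The replay path above rebuilds the page table by sampling `req_to_token` at page-aligned offsets (`strided_indices`) and converting token slot ids into page ids with an integer division by the page size. A toy reproduction of that indexing, with made-up shapes and values:

import torch

page_size = 4
max_context_len = 16

# Fake request-to-token map: two requests own token slots 0..15 and 16..31.
req_to_token = torch.arange(32, dtype=torch.int32).reshape(2, 16)
req_pool_indices = torch.tensor([0, 1])

strided_indices = torch.arange(0, max_context_len, page_size)  # tensor([0, 4, 8, 12])
max_seq_pages = 3  # enough pages for the longest sequence in this toy batch

page_indices = req_to_token[req_pool_indices[:, None], strided_indices[:max_seq_pages]]
page_table = page_indices // page_size

assert page_table.tolist() == [[0, 1, 2], [4, 5, 6]]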
@@ -177,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         device = seqlens_in_batch.device

         if forward_batch.forward_mode.is_decode_or_idle():
-
-
-
+            if forward_batch.spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+        elif forward_batch.forward_mode.is_target_verify():
+            # Only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
+            ).to(torch.int32)
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
+
         else:
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
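Several branches above build `cu_seqlens_*` tensors: a cumulative sum of per-sequence lengths with a leading zero added via `torch.nn.functional.pad`, plus an evenly spaced `torch.arange` on the query side when every sequence contributes a fixed number of query tokens. A tiny example with illustrative lengths:

import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Cumulative key lengths with a leading 0.
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
assert cu_seqlens_k.tolist() == [0, 3, 8, 10]

# In plain decode every sequence contributes exactly one query token.
cu_seqlens_q = torch.arange(0, len(seq_lens) + 1, dtype=torch.int32)
assert cu_seqlens_q.tolist() == [0, 1, 2, 3]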
@@ -193,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-            if
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -263,7 +563,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             workspace_buffer=self.workspace_buffer,
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
-            max_seq_len=self.
+            max_seq_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=layer.sliding_window_size,
@@ -318,7 +618,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
             max_q_len=self.forward_metadata.max_seq_len_q,
-            max_kv_len=self.
+            max_kv_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             batch_size=forward_batch.batch_size,
@@ -330,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
+    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+        for i in range(speculative_num_steps):
+            self.attn_backends[i] = TRTLLMHAAttnBackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                kv_last_page_len_buf=self.kv_last_page_len,
+                speculative_step_id=i,
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton

-from sglang.srt.layers.attention.flashinfer_mla_backend import
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
@@ -39,6 +42,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128

+global_zero_init_workspace_buffer = None
+

 @dataclass
 class TRTLLMMLADecodeMetadata:

@@ -83,13 +88,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-
-
-
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_zero_init_workspace_buffer

         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
@@ -160,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
+
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)

-        self.
+        self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.
+        self.decode_cuda_graph_workspace = torch.empty(
             self.workspace_size, dtype=torch.int8, device=self.device
         )

+        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -180,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-
-
+
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -192,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             spec_info,
         )

-        # Custom fast-path for decode/idle
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]

         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -208,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )

-        metadata = TRTLLMMLADecodeMetadata(
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata

@@ -224,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes
-        if not
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -258,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)

         bs = forward_batch.batch_size

@@ -467,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)

         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )