sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -10,13 +10,18 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    FlashInferAttnBackend,
+    FlashInferMultiStepDraftBackend,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     import flashinfer
 
+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -55,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
-
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        speculative_step_id: int = 0,
     ):
-        super().__init__(
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
 
         config = model_runner.model_config
 
@@ -87,6 +95,16 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
 
+        # Speculative decoding
+        # Only support topk <= 1 for now.
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_step_id = speculative_step_id
+        self.target_verify_metadata = {}
+
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+
         # Forward metadata
         self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
 
@@ -97,11 +115,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MHA."""
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "page_table": torch.zeros(
                 max_bs,
-
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -110,6 +129,70 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             ),
         }
 
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            self.draft_extend_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.zeros(
+                    max_bs + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -122,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
+        device = seq_lens.device
 
-
-
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][:bs, :]
+                self.decode_cuda_graph_metadata[bs] = metadata
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+                batch_size = len(seq_lens)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
 
-
-
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    :bs, :
+                ]
+                self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            # Target Verify
+            # Here we only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
 
-
-
-
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+
+            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
+
+            self.target_verify_metadata[bs] = metadata
+        elif forward_mode.is_draft_extend():
+            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+            num_tokens_per_bs = num_tokens // bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.max_seq_len_k = seq_lens.max().item()
+
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
+
+            self.draft_extend_metadata[bs] = metadata
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
@@ -149,21 +321,91 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        device = seq_lens.device
         metadata = None
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+
+                metadata.cache_seqlens_int32.copy_(
+                    seq_lens + self.speculative_step_id + 1
+                )
+            else:
+                # Normal Decode
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                    None, :
+                ],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        elif forward_mode.is_target_verify():
+            # Here we only support topk = 1 for now.
+            metadata = self.target_verify_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
 
-
-
-
-
-
-
-
-
-
-            self.
-
-
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+            ]
+            page_indices //= self.page_size
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        elif forward_mode.is_draft_extend():
+            metadata = self.draft_extend_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item()
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self) -> int:
@@ -179,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         device = seqlens_in_batch.device
 
         if forward_batch.forward_mode.is_decode_or_idle():
-
-
-
+            if forward_batch.spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+        elif forward_batch.forward_mode.is_target_verify():
+            # Only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
+            ).to(torch.int32)
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
+
         else:
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -195,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -265,7 +563,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             workspace_buffer=self.workspace_buffer,
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
-            max_seq_len=self.
+            max_seq_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=layer.sliding_window_size,
@@ -320,7 +618,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
             max_q_len=self.forward_metadata.max_seq_len_q,
-            max_kv_len=self.
+            max_kv_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             batch_size=forward_batch.batch_size,
@@ -332,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
+    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+        for i in range(speculative_num_steps):
+            self.attn_backends[i] = TRTLLMHAAttnBackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                kv_last_page_len_buf=self.kv_last_page_len,
+                speculative_step_id=i,
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
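The new CUDA-graph metadata paths above size the page table with a ceiling division over the page size and derive per-request page IDs by sampling one KV slot per page from `req_to_token` and floor-dividing by the page size. The snippet below is a minimal illustration of that arithmetic, not sglang code; the toy `req_to_token` layout and the sizes are made up for the example.

```python
import torch

# Illustrative sizes only (not taken from the package).
page_size = 4
max_context_len = 32

# Toy slot map: two requests, each with max_context_len token slots.
req_to_token = torch.arange(2 * max_context_len).reshape(2, max_context_len)

# Ceiling division, as in max_num_pages / max_seq_pages above.
max_seq_len_k = 10
max_seq_pages = (max_seq_len_k + page_size - 1) // page_size  # -> 3

# One slot is sampled per page; dividing by page_size yields the page index.
strided_indices = torch.arange(0, max_context_len, page_size)
page_table = req_to_token[:, strided_indices[:max_seq_pages]] // page_size
print(page_table)
```

With `page_size = 4` and a 10-token sequence, `max_seq_pages` is 3 and each row of `page_table` lists the three pages backing that request.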
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton
 
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
+
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
 
-        self.
+        self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.
+        self.decode_cuda_graph_workspace = torch.empty(
             self.workspace_size, dtype=torch.int8, device=self.device
         )
 
+        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-
-
+
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
 
-        # Custom fast-path for decode/idle
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        metadata = TRTLLMMLADecodeMetadata(
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
 
@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes
-        if not
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)
 
         bs = forward_batch.batch_size
@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )