sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -10,13 +10,18 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    FlashInferAttnBackend,
+    FlashInferMultiStepDraftBackend,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     import flashinfer
 
+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -55,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
-        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        speculative_step_id: int = 0,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
 
         config = model_runner.model_config
 
@@ -87,6 +95,16 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
 
+        # Speculative decoding
+        # Only support topk <= 1 for now.
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_step_id = speculative_step_id
+        self.target_verify_metadata = {}
+
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+
         # Forward metadata
         self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
 
@@ -97,11 +115,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MHA."""
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "page_table": torch.zeros(
                 max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -110,6 +129,70 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             ),
         }
 
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            self.draft_extend_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.zeros(
+                    max_bs + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
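
Note (illustrative aside, not part of the package diff): the preallocated "cu_seqlens_q" buffer for target verify above relies on torch.arange with a step equal to speculative_num_draft_tokens producing cumulative query offsets when every request contributes exactly that many draft tokens. A minimal sketch of that invariant, with made-up sizes:

import torch

# Hypothetical sizes, for illustration only.
max_bs = 4            # maximum batch size captured in the CUDA graph
num_draft = 3         # stand-in for speculative_num_draft_tokens

# Each request contributes exactly num_draft query tokens, so the cumulative
# query offsets form an arithmetic progression.
cu_seqlens_q = torch.arange(
    0, max_bs * num_draft + 1, step=num_draft, dtype=torch.int32
)
print(cu_seqlens_q)   # tensor([ 0,  3,  6,  9, 12], dtype=torch.int32)

# Same result via the cumsum-and-pad pattern used elsewhere in this backend.
lens = torch.full((max_bs,), num_draft, dtype=torch.int32)
alt = torch.nn.functional.pad(torch.cumsum(lens, dim=0, dtype=torch.int32), (1, 0))
assert torch.equal(cu_seqlens_q, alt)
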
@@ -122,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
+        device = seq_lens.device
 
-        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][:bs, :]
+                self.decode_cuda_graph_metadata[bs] = metadata
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+                batch_size = len(seq_lens)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
 
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = self.max_context_len
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    :bs, :
+                ]
+                self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            # Target Verify
+            # Here we only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
 
-        # Precompute page table
-        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
-        self.decode_cuda_graph_metadata[bs] = metadata
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+
+            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
+
+            self.target_verify_metadata[bs] = metadata
+        elif forward_mode.is_draft_extend():
+            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+            num_tokens_per_bs = num_tokens // bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.max_seq_len_k = seq_lens.max().item()
+
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
+
+            self.draft_extend_metadata[bs] = metadata
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
@@ -149,21 +321,91 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        device = seq_lens.device
         metadata = None
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+
+                metadata.cache_seqlens_int32.copy_(
+                    seq_lens + self.speculative_step_id + 1
+                )
+            else:
+                # Normal Decode
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                    None, :
+                ],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        elif forward_mode.is_target_verify():
+            # Here we only support topk = 1 for now.
+            metadata = self.target_verify_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
 
-        # Normal Decode
-        metadata = self.decode_cuda_graph_metadata[bs]
-        max_len = seq_lens_cpu.max().item()
-        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = self.max_context_len
-
-        metadata.cache_seqlens_int32.copy_(seq_lens)
-        page_indices = self.req_to_token[
-            req_pool_indices[:, None],
-            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
-        ]
-        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+            ]
+            page_indices //= self.page_size
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        elif forward_mode.is_draft_extend():
+            metadata = self.draft_extend_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item()
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self) -> int:
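
Note (illustrative aside, not part of the package diff): all three replay branches above rebuild the page table in place with the same indexing pattern: sample every page_size-th entry of req_to_token per request, then divide by page_size to turn token-slot indices into page ids, with the page count obtained by ceiling division. A small CPU-only sketch with assumed shapes:

import torch

page_size = 4                                     # tokens per KV-cache page (assumed)
max_context_len = 16
req_to_token = torch.arange(2 * max_context_len).reshape(2, max_context_len)

seq_lens = torch.tensor([10, 6])                  # current lengths of two requests
req_pool_indices = torch.tensor([1, 0])           # pool slot used by each request

max_len = seq_lens.max().item()
max_seq_pages = (max_len + page_size - 1) // page_size   # ceil division, as in the diff

# One token slot per page, converted to page ids.
strided_indices = torch.arange(0, max_context_len, page_size)
page_indices = req_to_token[
    req_pool_indices[:, None], strided_indices[:max_seq_pages][None, :]
]
page_table = page_indices // page_size
print(page_table)
# tensor([[4, 5, 6],
#         [0, 1, 2]])
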
@@ -179,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         device = seqlens_in_batch.device
 
         if forward_batch.forward_mode.is_decode_or_idle():
-            # Normal Decode
-            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            if forward_batch.spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+        elif forward_batch.forward_mode.is_target_verify():
+            # Only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
+            ).to(torch.int32)
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
+
         else:
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -195,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -265,7 +563,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             workspace_buffer=self.workspace_buffer,
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
-            max_seq_len=self.forward_metadata.max_seq_len_k,
+            max_seq_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=layer.sliding_window_size,
@@ -320,7 +618,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
             max_q_len=self.forward_metadata.max_seq_len_q,
-            max_kv_len=self.forward_metadata.max_seq_len_k,
+            max_kv_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             batch_size=forward_batch.batch_size,
@@ -332,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
+    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+        for i in range(speculative_num_steps):
+            self.attn_backends[i] = TRTLLMHAAttnBackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                kv_last_page_len_buf=self.kv_last_page_len,
+                speculative_step_id=i,
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
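
Note (illustrative aside, not part of the package diff): each per-step backend constructed above receives its own speculative_step_id, and the draft-decode branches extend the effective KV length by speculative_step_id + 1, presumably because step i decodes one more draft token on top of the i tokens already appended by earlier steps. A tiny sketch with assumed lengths:

import torch

# Assumed prefix lengths for two requests.
seq_lens = torch.tensor([128, 64], dtype=torch.int32)
speculative_num_steps = 3

# KV lengths the attention kernel would see at each draft step
# (mirrors `seq_lens + self.speculative_step_id + 1` in the diff).
for step_id in range(speculative_num_steps):
    cache_seqlens = seq_lens + step_id + 1
    print(step_id, cache_seqlens.tolist())
# 0 [129, 65]
# 1 [130, 66]
# 2 [131, 67]
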
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton
 
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.cuda_graph_kv_indices = None
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
+
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
 
-        self.cuda_graph_kv_indices = torch.full(
+        self.decode_cuda_graph_kv_indices = torch.full(
            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
        )
-        self.cuda_graph_workspace = torch.empty(
+        self.decode_cuda_graph_workspace = torch.empty(
            self.workspace_size, dtype=torch.int8, device=self.device
        )
 
+        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
 
-        # Custom fast-path for decode/idle without speculative execution.
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
 
@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)
 
         bs = forward_batch.batch_size
@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )