sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/aiter_backend.py

@@ -32,7 +32,7 @@ try:
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
-    from aiter.mla import mla_decode_fwd
+    from aiter.mla import mla_decode_fwd, mla_prefill_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
@@ -52,10 +52,8 @@ class ForwardMetadata:
     kv_indices: torch.Tensor
     qo_indptr: torch.Tensor
     kv_last_page_len: torch.Tensor
-    max_extend_len: int
-    max_prefix_extend_len: int
     max_q_len: int
-    max_kv_len: int
+    max_kv_len: Optional[int]


 global_workspace_buffer = None
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
         spec_info = forward_batch.spec_info
         qo_indptr = None
         kv_last_page_len = None
-        max_extend_len = None
+        max_q_len = None

         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.zeros(
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr = self.qo_indptr_[: bs + 1]
                 qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                 kv_last_page_len = self.kv_last_page_len[:bs]
-                max_extend_len = 1
+                max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )

         elif forward_batch.forward_mode.is_draft_extend():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
+                kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                    spec_info.generate_attn_arg_prefill(
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.seq_lens_sum,
+                        self.req_to_token,
+                    )
                 )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
         elif forward_batch.forward_mode.is_target_verify():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
+                draft_num = spec_info.draft_token_num
+                kv_lens = forward_batch.seq_lens + draft_num
+                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+                device = forward_batch.seq_lens.device
+
+                qo_indptr = torch.arange(
+                    0,
+                    (1 + bs) * draft_num,
+                    step=draft_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    kv_lens_sum,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
                     forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
+                    kv_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
                 )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    draft_num,
                     None,
                 )
             else:
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
                 extend_no_prefix = False
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
+                    forward_batch.extend_prefix_lens,
+                    sum(forward_batch.extend_prefix_lens_cpu),
                     forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                     spec_info=None,
                 )
+                self.mla_indices_updater_prefill.kv_indptr += (
+                    self.mla_indices_updater_prefill.qo_indptr
+                )
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
                     self.mla_indices_updater_prefill.kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    self.kv_last_page_len[:bs],
+                    self.mla_indices_updater_prefill.max_q_len,
+                    self.mla_indices_updater_prefill.max_kv_len,
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
-            max_extend_len = None
+            max_q_len = None

             if spec_info is None:
                 kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr[1 : bs + 1] = torch.cumsum(
                     self.cuda_graph_kv_last_page_len[:bs], dim=0
                 )
-                max_extend_len = 1
                 kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )

@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
-
-                max_extend_len = self.num_draft_tokens
-                kv_last_page_len = None
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = self.num_draft_tokens

                 self.forward_metadata = ForwardMetadata(
                     kv_indptr,
                     kv_indices,
                     qo_indptr,
                     kv_last_page_len,
-                    max_extend_len,
-                    None,
-                    None,
+                    max_q_len,
                     None,
                 )
             else:
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
-
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = num_tokens_per_bs
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_q_len,
+                None,
+            )
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")

@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices

         elif forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                req_pool_indices[:bs],
-                seq_lens[:bs],
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-                spec_info=spec_info,
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
             )
         else:
             raise ValueError("Invalid forward mode")
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
             )

         if self.use_mla:
-            max_extend_len = self.forward_metadata.max_extend_len
-            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            max_q_len = self.forward_metadata.max_q_len
+            max_kv_len = self.forward_metadata.max_kv_len
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
-            kv_last_page_lens = self.forward_metadata.kv_last_page_len
             qo_indptr = self.forward_metadata.qo_indptr
             K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     qo_indptr,
-                    max_extend_len,
-                    max_extend_len,
+                    max_q_len,
+                    max_q_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     kv_indptr,
-                    max_extend_len,
-                    max_prefix_extend_len,
+                    max_q_len,
+                    max_kv_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
                 return o
+            elif forward_batch.forward_mode.is_target_verify():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                mla_decode_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+            elif forward_batch.forward_mode.is_draft_extend():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                causal = True
+                sliding_window_size = -1
+                kv_indptr = self.forward_metadata.kv_indptr
+                kv_indices = self.forward_metadata.kv_indices
+                mla_prefill_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+                # self.extend_attention_fwd(
+                #     q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                #     k.contiguous(),
+                #     v.contiguous(),
+                #     o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                #     forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                #     forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                #     self.forward_metadata.qo_indptr,
+                #     kv_indptr,
+                #     kv_indices,
+                #     None,
+                #     causal,
+                #     None,
+                #     self.forward_metadata.max_q_len,
+                #     layer.scaling,
+                #     layer.logit_cap,
+                #     sliding_window_size,
+                # )
+                # return o
+            else:
+                raise ValueError(
+                    f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+                )

         else:
             k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
                 self.forward_metadata.kv_indptr,
                 self.forward_metadata.kv_indices,
                 self.forward_metadata.kv_last_page_len,
-                self.forward_metadata.max_extend_len,
+                self.forward_metadata.max_q_len,
                 layer.scaling,
                 layer.logit_cap,
             )
@@ -720,11 +848,6 @@ class AiterIndicesUpdaterPrefill:
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.update = self.update_single_wrapper

-        # get the last index of the pool
-        self.pool_size = (
-            model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
-        ) - 1
-
         self.kv_indices = None
         self.max_q_len = 0
         self.max_kv_len = 0
@@ -769,9 +892,8 @@ class AiterIndicesUpdaterPrefill:
         # but the 0 location will be made nan (noqa) in cuda graph capture mode
         # this will cause the output tensor value becomes nan
         # WA is to assure that last index of pool not changed
-        kv_indices = torch.full(
-            (paged_kernel_lens_sum + 128,),
-            self.pool_size,
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + 256,
             dtype=torch.int32,
             device=req_pool_indices.device,
         )
@@ -785,6 +907,9 @@ class AiterIndicesUpdaterPrefill:
             self.req_to_token.shape[1],
         )

+        token_num = kv_indptr[-1]
+        kv_indices[token_num:] = kv_indices[0]
+
         self.max_kv_len = torch.max(paged_kernel_lens).item()

         extend_lens = seq_lens - prefix_lens
@@ -819,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indices = None
         self.qo_indptr = None
         self.kv_last_page_len = None
-        self.max_extend_len = 0
-        self.max_prefix_extend_len = 0
+        self.max_q_len = 0
+        self.max_kv_len = 0

     def update(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -837,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
-
-        paged_kernel_lens = prefix_lens
-        paged_kernel_lens_sum = prefix_lens_sum
-
         bs = len(req_pool_indices)

         kv_indptr = self.attn_backend.kv_indptr

         if spec_info is None:
             # Normal extend
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum,
+                kv_lens_sum,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
-                paged_kernel_lens,
+                kv_lens,
                 kv_indptr,
                 None,
                 kv_indices,
@@ -873,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
             qo_indptr = self.attn_backend.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
-
-            max_extend_len = torch.max(extend_lens).item()
-            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
-            kv_indptr += qo_indptr
         else:
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
-                    paged_kernel_lens,
-                    paged_kernel_lens_sum,
+                    kv_lens,
+                    kv_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -890,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
         self.qo_indptr = qo_indptr
-        self.max_extend_len = max_extend_len
-        self.max_prefix_extend_len = max_prefix_extend_len
+        self.max_q_len = max_q_len
+        self.max_kv_len = max_kv_len
+
+
+class AiterMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                AiterAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
+        assert self.page_size == 1, "Page size must be 1"
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+            self.page_size,
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.empty(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)