sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff compares publicly available package versions as they were released to one of the supported registries. It is provided for informational purposes only.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
Diff for sglang/srt/layers/attention/aiter_backend.py:

@@ -32,7 +32,7 @@ try:
  mha_batch_prefill_func,
  paged_attention_ragged,
  )
- from aiter.mla import mla_decode_fwd
+ from aiter.mla import mla_decode_fwd, mla_prefill_fwd
  except ImportError:
  print(
  "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
@@ -52,10 +52,8 @@ class ForwardMetadata:
  kv_indices: torch.Tensor
  qo_indptr: torch.Tensor
  kv_last_page_len: torch.Tensor
- max_extend_len: int
- max_prefix_extend_len: int
  max_q_len: int
- max_kv_len: int
+ max_kv_len: Optional[int]


  global_workspace_buffer = None
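Net effect of the hunk above: ForwardMetadata now carries a single max_q_len/max_kv_len pair instead of the old max_extend_len/max_prefix_extend_len fields. A minimal sketch of the resulting dataclass, reconstructed from the fields used throughout this diff (not copied verbatim from the release):

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class ForwardMetadata:
        kv_indptr: torch.Tensor        # per-request offsets into kv_indices
        kv_indices: torch.Tensor       # flattened KV-cache token indices
        qo_indptr: torch.Tensor        # per-request offsets into the query tokens
        kv_last_page_len: torch.Tensor
        max_q_len: int                 # longest query segment in the batch
        max_kv_len: Optional[int]      # longest KV length; None on decode paths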
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
  kv_indptr_buf: Optional[torch.Tensor] = None,
  ):
  super().__init__()
+ # Lazy import to avoid the initialization of cuda context
+ from sglang.srt.layers.attention.triton_ops.extend_attention import (
+ extend_attention_fwd,
+ )
+
+ self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

  self.device = model_runner.device
  self.is_multimodal = model_runner.model_config.is_multimodal
  self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+ self.speculative_num_steps = model_runner.server_args.speculative_num_steps
  self.num_head = (
  model_runner.model_config.num_attention_heads // get_attention_tp_size()
  )
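The lines added above follow a deferred-import pattern: the Triton extend-attention kernel is imported inside __init__ so that importing the backend module does not initialize the CUDA context, and the launcher is wrapped with torch.compiler.disable so torch.compile will not trace into it. A runnable toy illustration of the same pattern (the kernel body below is a stand-in, not the sglang kernel):

    import torch

    class LazyKernelBackend:
        def __init__(self):
            # In the real backend a heavy Triton kernel module is imported here,
            # at construction time, instead of at module import time.
            def extend_attention_fwd(x: torch.Tensor) -> torch.Tensor:
                return x + 1  # stand-in for the actual kernel launch

            # Callables wrapped by torch.compiler.disable are skipped by
            # torch.compile graph capture and always run eagerly.
            self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

    backend = LazyKernelBackend()
    print(backend.extend_attention_fwd(torch.zeros(2)))  # tensor([1., 1.])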
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
  spec_info = forward_batch.spec_info
  qo_indptr = None
  kv_last_page_len = None
- max_extend_len = None
+ max_q_len = None

  if forward_batch.forward_mode.is_decode_or_idle():
  if spec_info is None:
  kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indices = torch.zeros(
+ kv_indices = torch.empty(
  forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
  )
  create_flashinfer_kv_indices_triton[(bs,)](
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
  qo_indptr = self.qo_indptr_[: bs + 1]
  qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
  kv_last_page_len = self.kv_last_page_len[:bs]
- max_extend_len = 1
+ max_q_len = 1

  self.forward_metadata = ForwardMetadata(
  kv_indptr,
  kv_indices,
  qo_indptr,
  kv_last_page_len,
- max_extend_len,
- None,
- None,
+ max_q_len,
  None,
  )

  elif forward_batch.forward_mode.is_draft_extend():
  if self.use_mla:
- prefix_lens = forward_batch.extend_prefix_lens
- self.mla_indices_updater_prefill.update(
- forward_batch.req_pool_indices,
- prefix_lens,
- prefix_lens.sum().item(),
- forward_batch.extend_seq_lens,
- encoder_lens=forward_batch.encoder_lens,
- spec_info=None,
+ kv_indices, kv_indptr, qo_indptr, custom_mask = (
+ spec_info.generate_attn_arg_prefill(
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ self.req_to_token,
+ )
  )
  self.forward_metadata = ForwardMetadata(
- self.mla_indices_updater_prefill.kv_indptr,
- self.mla_indices_updater_prefill.kv_indices,
- self.mla_indices_updater_prefill.qo_indptr,
- self.mla_indices_updater_prefill.kv_last_page_len,
- self.mla_indices_updater_prefill.max_extend_len,
- self.mla_indices_updater_prefill.max_prefix_extend_len,
- None,
- None,
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ # self.mla_indices_updater_prefill.kv_last_page_len,
+ self.kv_last_page_len[:bs],
+ max(forward_batch.extend_seq_lens_cpu),
+ forward_batch.seq_lens_cpu.max().item(),
  )
  else:
  self.indices_updater_prefill.update(
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
  self.indices_updater_prefill.kv_indices,
  None,
  None,
- None,
- None,
  self.indices_updater_prefill.max_q_len,
  self.indices_updater_prefill.max_kv_len,
  )
  elif forward_batch.forward_mode.is_target_verify():
  if self.use_mla:
- prefix_lens = forward_batch.extend_prefix_lens
- self.mla_indices_updater_prefill.update(
+ draft_num = spec_info.draft_token_num
+ kv_lens = forward_batch.seq_lens + draft_num
+ kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+ device = forward_batch.seq_lens.device
+
+ qo_indptr = torch.arange(
+ 0,
+ (1 + bs) * draft_num,
+ step=draft_num,
+ dtype=torch.int32,
+ device=device,
+ )
+ kv_indptr = self.kv_indptr
+ kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ kv_lens_sum,
+ dtype=torch.int32,
+ device=device,
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
  forward_batch.req_pool_indices,
- prefix_lens,
- prefix_lens.sum().item(),
- forward_batch.extend_seq_lens,
- encoder_lens=forward_batch.encoder_lens,
- spec_info=None,
+ kv_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.stride(0),
  )
  self.forward_metadata = ForwardMetadata(
- self.mla_indices_updater_prefill.kv_indptr,
- self.mla_indices_updater_prefill.kv_indices,
- self.mla_indices_updater_prefill.qo_indptr,
- self.mla_indices_updater_prefill.kv_last_page_len,
- self.mla_indices_updater_prefill.max_extend_len,
- self.mla_indices_updater_prefill.max_prefix_extend_len,
- None,
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ # self.mla_indices_updater_prefill.kv_last_page_len,
+ self.kv_last_page_len[:bs],
+ draft_num,
  None,
  )
  else:
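In the new target-verify branch above, every request contributes exactly draft_token_num query tokens, so qo_indptr can be built with a strided arange, while kv_indptr still needs a cumsum because KV lengths differ per request. A small worked example with made-up sizes (bs = 3 requests, draft_num = 4):

    import torch

    bs, draft_num = 3, 4
    qo_indptr = torch.arange(0, (1 + bs) * draft_num, step=draft_num, dtype=torch.int32)
    print(qo_indptr)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)

    seq_lens = torch.tensor([7, 2, 5], dtype=torch.int32)
    kv_lens = seq_lens + draft_num                   # [11, 6, 9]
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)     # running totals of KV lengths
    print(kv_indptr)  # tensor([ 0, 11, 17, 26], dtype=torch.int32)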
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
  self.indices_updater_prefill.kv_indices,
  None,
  None,
- None,
- None,
  self.indices_updater_prefill.max_q_len,
  self.indices_updater_prefill.max_kv_len,
  )
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
  extend_no_prefix = False
  else:
  extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-
  if self.use_mla:
  self.mla_indices_updater_prefill.update(
  forward_batch.req_pool_indices,
- prefix_lens,
- prefix_lens.sum().item(),
+ forward_batch.extend_prefix_lens,
+ sum(forward_batch.extend_prefix_lens_cpu),
  forward_batch.extend_seq_lens,
- encoder_lens=forward_batch.encoder_lens,
+ max(forward_batch.extend_seq_lens_cpu),
+ forward_batch.seq_lens_cpu.max().item(),
  spec_info=None,
  )
+ self.mla_indices_updater_prefill.kv_indptr += (
+ self.mla_indices_updater_prefill.qo_indptr
+ )
  self.forward_metadata = ForwardMetadata(
  self.mla_indices_updater_prefill.kv_indptr,
  self.mla_indices_updater_prefill.kv_indices,
  self.mla_indices_updater_prefill.qo_indptr,
- self.mla_indices_updater_prefill.kv_last_page_len,
- self.mla_indices_updater_prefill.max_extend_len,
- self.mla_indices_updater_prefill.max_prefix_extend_len,
- None,
- None,
+ self.kv_last_page_len[:bs],
+ self.mla_indices_updater_prefill.max_q_len,
+ self.mla_indices_updater_prefill.max_kv_len,
  )
  else:
  self.indices_updater_prefill.update(
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
  self.indices_updater_prefill.kv_indices,
  None,
  None,
- None,
- None,
  self.indices_updater_prefill.max_q_len,
  self.indices_updater_prefill.max_kv_len,
  )
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
  if forward_mode.is_decode_or_idle():
  qo_indptr = None
  kv_last_page_len = None
- max_extend_len = None
+ max_q_len = None

  if spec_info is None:
  kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
  qo_indptr[1 : bs + 1] = torch.cumsum(
  self.cuda_graph_kv_last_page_len[:bs], dim=0
  )
- max_extend_len = 1
  kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+ max_q_len = 1

  self.forward_metadata = ForwardMetadata(
  kv_indptr,
  kv_indices,
  qo_indptr,
  kv_last_page_len,
- max_extend_len,
- None,
- None,
+ max_q_len,
  None,
  )

@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
  kv_indices,
  self.req_to_token.stride(0),
  )
-
- max_extend_len = self.num_draft_tokens
- kv_last_page_len = None
+ kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+ max_q_len = self.num_draft_tokens

  self.forward_metadata = ForwardMetadata(
  kv_indptr,
  kv_indices,
  qo_indptr,
  kv_last_page_len,
- max_extend_len,
- None,
- None,
+ max_q_len,
  None,
  )
  else:
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
  self.indices_updater_prefill.kv_indices,
  None,
  None,
- None,
- None,
  self.indices_updater_prefill.max_q_len,
  self.indices_updater_prefill.max_kv_len,
  )
-
+ elif forward_mode.is_draft_extend():
+ num_tokens_per_bs = self.speculative_num_steps + 1
+ qo_indptr = self.qo_indptr[: bs + 1]
+ qo_indptr[: bs + 1] = torch.arange(
+ 0,
+ bs * num_tokens_per_bs + 1,
+ step=num_tokens_per_bs,
+ dtype=torch.int32,
+ device=self.device,
+ )
+ kv_indptr = self.kv_indptr[: bs + 1]
+ kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+ kv_indices = self.cuda_graph_kv_indices
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ seq_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.stride(0),
+ )
+ kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+ max_q_len = num_tokens_per_bs
+ self.forward_metadata = ForwardMetadata(
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ kv_last_page_len,
+ max_q_len,
+ None,
+ )
  else:
  raise ValueError(f"Invalid mode: {forward_mode=}")

@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
  kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices

  elif forward_mode.is_target_verify():
- self.indices_updater_prefill.update(
- req_pool_indices[:bs],
- seq_lens[:bs],
- seq_lens_sum,
- prefix_lens=None,
- encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
- spec_info=spec_info,
+ bs = len(req_pool_indices)
+ qo_indptr = self.qo_indptr[: bs + 1]
+ qo_indptr[: bs + 1] = torch.arange(
+ 0,
+ (1 + bs) * self.num_draft_tokens,
+ step=self.num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ )
+ kv_lens = seq_lens + self.num_draft_tokens
+ kv_indptr = self.kv_indptr[: bs + 1]
+ kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+ kv_indices = self.cuda_graph_kv_indices
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ kv_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.stride(0),
+ )
+ elif forward_mode.is_draft_extend():
+ seq_lens = seq_lens[:bs]
+ accept_lens = spec_info.accept_length[:bs]
+ qo_indptr = self.qo_indptr[: bs + 1]
+ qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+ kv_indptr = self.kv_indptr[: bs + 1]
+ kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+ kv_indices = self.cuda_graph_kv_indices
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ seq_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.stride(0),
  )
  else:
  raise ValueError("Invalid forward mode")
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
  )

  if self.use_mla:
- max_extend_len = self.forward_metadata.max_extend_len
- max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+ max_q_len = self.forward_metadata.max_q_len
+ max_kv_len = self.forward_metadata.max_kv_len
  kv_indptr = self.forward_metadata.kv_indptr
  kv_indices = self.forward_metadata.kv_indices
- kv_last_page_lens = self.forward_metadata.kv_last_page_len
  qo_indptr = self.forward_metadata.qo_indptr
  K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
  V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
  v,
  qo_indptr,
  qo_indptr,
- max_extend_len,
- max_extend_len,
+ max_q_len,
+ max_q_len,
  softmax_scale=layer.scaling,
  causal=True,
  )
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
  v,
  qo_indptr,
  kv_indptr,
- max_extend_len,
- max_prefix_extend_len,
+ max_q_len,
+ max_kv_len,
  softmax_scale=layer.scaling,
  causal=True,
  )
  return o
+ elif forward_batch.forward_mode.is_target_verify():
+ o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+ mla_decode_fwd(
+ q,
+ K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+ o,
+ self.forward_metadata.qo_indptr,
+ self.forward_metadata.kv_indptr,
+ self.forward_metadata.kv_indices,
+ self.forward_metadata.kv_last_page_len,
+ self.forward_metadata.max_q_len,
+ layer.scaling,
+ layer.logit_cap,
+ )
+ K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+ return o
+ elif forward_batch.forward_mode.is_draft_extend():
+ o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+ causal = True
+ sliding_window_size = -1
+ kv_indptr = self.forward_metadata.kv_indptr
+ kv_indices = self.forward_metadata.kv_indices
+ mla_prefill_fwd(
+ q,
+ K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+ o,
+ self.forward_metadata.qo_indptr,
+ self.forward_metadata.kv_indptr,
+ self.forward_metadata.kv_indices,
+ self.forward_metadata.kv_last_page_len,
+ self.forward_metadata.max_q_len,
+ layer.scaling,
+ layer.logit_cap,
+ )
+ K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+ return o
+ # self.extend_attention_fwd(
+ # q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+ # k.contiguous(),
+ # v.contiguous(),
+ # o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+ # forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+ # forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+ # self.forward_metadata.qo_indptr,
+ # kv_indptr,
+ # kv_indices,
+ # None,
+ # causal,
+ # None,
+ # self.forward_metadata.max_q_len,
+ # layer.scaling,
+ # layer.logit_cap,
+ # sliding_window_size,
+ # )
+ # return o
+ else:
+ raise ValueError(
+ f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+ )
  else:
  k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
  layer.layer_id
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
  self.forward_metadata.kv_indptr,
  self.forward_metadata.kv_indices,
  self.forward_metadata.kv_last_page_len,
- self.forward_metadata.max_extend_len,
+ self.forward_metadata.max_q_len,
  layer.scaling,
  layer.logit_cap,
  )
@@ -816,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
  self.kv_indices = None
  self.qo_indptr = None
  self.kv_last_page_len = None
- self.max_extend_len = 0
- self.max_prefix_extend_len = 0
+ self.max_q_len = 0
+ self.max_kv_len = 0

  def update(
  self,
  req_pool_indices: torch.Tensor,
- prefix_lens: torch.Tensor,
- prefix_lens_sum: int,
+ kv_lens: torch.Tensor,
+ kv_lens_sum: int,
  extend_lens: torch.Tensor,
- encoder_lens: Optional[torch.Tensor],
+ max_q_len: int,
+ max_kv_len: int,
  spec_info: Optional[SpecInfo],
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
@@ -834,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
  def update_single_wrapper(
  self,
  req_pool_indices: torch.Tensor,
- prefix_lens: torch.Tensor,
- prefix_lens_sum: int,
+ kv_lens: torch.Tensor,
+ kv_lens_sum: int,
  extend_lens: torch.Tensor,
- encoder_lens: Optional[torch.Tensor],
+ max_q_len: int,
+ max_kv_len: int,
  spec_info: Optional[SpecInfo],
  ):
-
- paged_kernel_lens = prefix_lens
- paged_kernel_lens_sum = prefix_lens_sum
-
  bs = len(req_pool_indices)

  kv_indptr = self.attn_backend.kv_indptr

  if spec_info is None:
  # Normal extend
- kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
  kv_indices = torch.empty(
- paged_kernel_lens_sum,
+ kv_lens_sum,
  dtype=torch.int32,
  device=req_pool_indices.device,
  )
  create_flashinfer_kv_indices_triton[(bs,)](
  self.req_to_token,
  req_pool_indices,
- paged_kernel_lens,
+ kv_lens,
  kv_indptr,
  None,
  kv_indices,
@@ -870,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
  qo_indptr = self.attn_backend.qo_indptr
  qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
  qo_indptr = qo_indptr[: bs + 1]
-
- max_extend_len = torch.max(extend_lens).item()
- max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
- kv_indptr += qo_indptr
  else:
  kv_indices, kv_indptr, qo_indptr, custom_mask = (
  spec_info.generate_attn_arg_prefill(
  req_pool_indices,
- paged_kernel_lens,
- paged_kernel_lens_sum,
+ kv_lens,
+ kv_lens_sum,
  self.req_to_token,
  )
  )
@@ -887,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
  self.kv_indptr = kv_indptr
  self.kv_indices = kv_indices
  self.qo_indptr = qo_indptr
- self.max_extend_len = max_extend_len
- self.max_prefix_extend_len = max_prefix_extend_len
+ self.max_q_len = max_q_len
+ self.max_kv_len = max_kv_len
+
+
+ class AiterMultiStepDraftBackend:
+ """
+ Wrap multiple triton attention backends as one for multiple consecutive
+ draft decoding steps.
+ """
+
+ def __init__(
+ self,
+ model_runner: ModelRunner,
+ topk: int,
+ speculative_num_steps: int,
+ ):
+ from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+ self.topk = topk
+ self.speculative_num_steps = speculative_num_steps
+ self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+ max_bs = model_runner.req_to_token_pool.size * self.topk
+ self.kv_indptr = torch.zeros(
+ (
+ self.speculative_num_steps,
+ max_bs + 1,
+ ),
+ dtype=torch.int32,
+ device=model_runner.device,
+ )
+ self.attn_backends = []
+ for i in range(self.speculative_num_steps):
+ self.attn_backends.append(
+ AiterAttnBackend(
+ model_runner,
+ skip_prefill=True,
+ kv_indptr_buf=self.kv_indptr[i],
+ )
+ )
+ self.max_context_len = self.attn_backends[0].max_context_len
+ self.num_head = (
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
+ )
+ self.device = model_runner.device
+ # Cached variables for generate_draft_decode_kv_indices
+ self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+ self.page_size = model_runner.server_args.page_size
+ assert self.page_size == 1, "Page size must be 1"
+
+ def common_template(
+ self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+ ):
+ num_seqs = forward_batch.batch_size
+ bs = self.topk * num_seqs
+ seq_lens_sum = forward_batch.seq_lens_sum
+
+ self.generate_draft_decode_kv_indices[
+ (self.speculative_num_steps, num_seqs, self.topk)
+ ](
+ forward_batch.req_pool_indices,
+ forward_batch.req_to_token_pool.req_to_token,
+ forward_batch.seq_lens,
+ kv_indices_buffer,
+ self.kv_indptr,
+ forward_batch.positions,
+ self.pool_len,
+ kv_indices_buffer.shape[1],
+ self.kv_indptr.shape[1],
+ triton.next_power_of_2(num_seqs),
+ triton.next_power_of_2(self.speculative_num_steps),
+ triton.next_power_of_2(bs),
+ self.page_size,
+ )
+
+ for i in range(self.speculative_num_steps):
+ forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+ forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+ : seq_lens_sum * self.topk + bs * (i + 1)
+ ]
+ call_fn(i, forward_batch)
+
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
+ kv_indices = torch.empty(
+ (
+ self.speculative_num_steps,
+ forward_batch.batch_size * self.topk * self.max_context_len,
+ ),
+ dtype=torch.int32,
+ device=self.device,
+ )
+
+ def call_fn(i, forward_batch):
+ forward_batch.spec_info.kv_indptr = (
+ forward_batch.spec_info.kv_indptr.clone()
+ )
+ forward_batch.spec_info.kv_indices = (
+ forward_batch.spec_info.kv_indices.clone()
+ )
+ self.attn_backends[i].init_forward_metadata(forward_batch)
+
+ self.common_template(forward_batch, kv_indices, call_fn)
+
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+ self.cuda_graph_kv_indices = torch.zeros(
+ (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+ dtype=torch.int32,
+ device=self.device,
+ )
+ for i in range(self.speculative_num_steps):
+ self.attn_backends[i].init_cuda_graph_state(
+ max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+ )
+
+ def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+ def call_fn(i, forward_batch):
+ self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+ forward_batch.batch_size,
+ forward_batch.batch_size * self.topk,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ )
+
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+ def init_forward_metadata_replay_cuda_graph(
+ self, forward_batch: ForwardBatch, bs: int
+ ):
+ def call_fn(i, forward_batch):
+ self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+ bs,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ seq_lens_sum=-1,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ seq_lens_cpu=None,
+ )
+
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)