sglang 0.4.3.post1-py3-none-any.whl → 0.4.3.post3-py3-none-any.whl

This diff shows the contents of publicly available package versions as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -19,16 +19,16 @@ import triton
 import triton.language as tl

 from sglang.global_config import global_config
-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (
@@ -37,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper
+    from flashinfer.decode import PosEncodingMode


 class WrapperDispatch(Enum):
@@ -47,9 +47,7 @@ class WrapperDispatch(Enum):

 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
-        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-    ]
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]


 @dataclass
@@ -71,6 +69,7 @@ class FlashInferAttnBackend(AttentionBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()

@@ -107,12 +106,6 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

-        self.enable_flashinfer_mla = False
-        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                self.enable_flashinfer_mla = True
-                global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -122,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
         self.workspace_buffer = global_workspace_buffer
+
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -130,24 +124,25 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
-            if self.enable_flashinfer_mla:
-                self.qo_indptr = [
-                    torch.zeros(
-                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-                    )
-                    for _ in range(self.num_wrappers)
-                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]

-        self.kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device=model_runner.device
-        )
-        self.qo_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if kv_last_page_len_buf is None:
+            self.kv_last_page_len = torch.ones(
+                (max_bs,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            assert self.num_wrappers == 1
+            self.kv_last_page_len = kv_last_page_len_buf
+
+        if not self.skip_prefill:
+            self.qo_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]

         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
@@ -170,18 +165,14 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            if self.enable_flashinfer_mla:
-                self.decode_wrappers.append(
-                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
-                )
-            else:
-                self.decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                    )
+
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
                 )
+            )

         # Create indices updater
         if not skip_prefill:
@@ -291,37 +282,25 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                if self.enable_flashinfer_mla:
-                    decode_wrappers.append(
-                        BatchMLAPagedAttentionWrapper(
-                            self.workspace_buffer,
-                            use_cuda_graph=True,
-                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
-                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
-                            kv_indices=self.cuda_graph_kv_indices[i],
-                            kv_len_arr=self.kv_last_page_len[:num_tokens],
-                            backend="fa2",
-                        )
-                    )
-                else:
-                    decode_wrappers.append(
-                        BatchDecodeWithPagedKVCacheWrapper(
-                            self.workspace_buffer,
-                            "NHD",
-                            use_cuda_graph=True,
-                            use_tensor_cores=self.decode_use_tensor_cores,
-                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                                :num_tokens
-                            ],
-                        )
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
                     )
+                )
+
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -373,7 +352,8 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -410,94 +390,64 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        if global_config.enable_flashinfer_mla:
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

-            logits_soft_cap = layer.logit_cap
+        logits_soft_cap = layer.logit_cap
+
+        if not self.forward_metadata.use_ragged:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

-            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            o = prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                causal=not layer.is_cross_attention,
+                sm_scale=layer.scaling,
+                window_left=layer.sliding_window_size,
+                logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+        else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )

-            o = o1
-
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-                self._get_wrapper_idx(layer)
-            ]
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-
-            logits_soft_cap = layer.logit_cap
-
-            if not self.forward_metadata.use_ragged:
-                if k is not None:
-                    assert v is not None
-                    if save_kv_cache:
-                        forward_batch.token_to_kv_pool.set_kv_buffer(
-                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                        )
-
-                o = prefill_wrapper_paged.forward(
+            if self.forward_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=not layer.is_cross_attention,
+                    causal=False,
                     sm_scale=layer.scaling,
-                    window_left=layer.sliding_window_size,
-                    logits_soft_cap=logits_soft_cap,
-                    k_scale=layer.k_scale,
-                    v_scale=layer.v_scale,
+                    logits_soft_cap=layer.logit_cap,
                 )
-            else:
-                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
-
-                if self.forward_metadata.extend_no_prefix:
-                    o = o1
-                else:
-                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                        causal=False,
-                        sm_scale=layer.scaling,
-                        logits_soft_cap=layer.logit_cap,
-                    )

-                    o, _ = merge_state(o1, s1, o2, s2)
+                o, _ = merge_state(o1, s1, o2, s2)

-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
         self,
@@ -517,45 +467,23 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

-        if self.enable_flashinfer_mla:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
-            o = decode_wrapper.run(
-                reshaped_q[:, :, : layer.v_head_dim],
-                reshaped_q[:, :, layer.v_head_dim :],
-                reshaped_k[:, :, : layer.v_head_dim],
-                reshaped_k[:, :, layer.v_head_dim :],
-            )
-
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

-            o = decode_wrapper.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
+        o = decode_wrapper.forward(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            sm_scale=layer.scaling,
+            logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
+        )

-            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
@@ -603,11 +531,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -617,11 +543,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
-            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
-        ],
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -641,7 +565,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -675,7 +599,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -700,15 +624,13 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper: Union[
-            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
        kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -727,40 +649,21 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
+            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
-
-        if global_config.enable_flashinfer_mla:
-            sm_scale = 1.0 / math.sqrt(192)
-            q_indptr = torch.arange(0, bs + 1).to(0).int()
-            kv_lens = paged_kernel_lens.to(torch.int32)
-            wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_qo_heads,
-                512,
-                64,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
-        else:
-            wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                data_type=self.data_type,
-                q_data_type=self.q_data_type,
-                non_blocking=True,
-            )
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+            non_blocking=True,
+        )


 class FlashInferIndicesUpdaterPrefill:
@@ -803,7 +706,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -817,7 +720,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -850,7 +753,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -891,7 +794,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -933,10 +836,11 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        bs = len(req_pool_indices)
+        bs = len(seq_lens)
         if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
@@ -959,52 +863,49 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )

         # extend part
         if use_ragged:
-            if global_config.enable_flashinfer_mla:
-                wrapper_ragged.begin_forward(
-                    qo_indptr=qo_indptr,
-                    kv_indptr=qo_indptr,
-                    num_qo_heads=self.num_qo_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_dim_qk=192,
-                    head_dim_vo=128,
-                    q_data_type=self.q_data_type,
-                )
-            else:
-                wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    q_data_type=self.q_data_type,
-                )
-
-        if not global_config.enable_flashinfer_mla:
-            # cached part
-            wrapper_paged.begin_forward(
+            wrapper_ragged.begin_forward(
+                qo_indptr,
                 qo_indptr,
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
-                1,
                 q_data_type=self.q_data_type,
-                custom_mask=custom_mask,
-                non_blocking=True,
             )

+        # cached part
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
+            non_blocking=True,
+        )
+
+
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global global_override_indptr_cpu
+

 class FlashInferMultiStepDraftBackend:
     """
@@ -1023,7 +924,8 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1032,6 +934,9 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
@@ -1039,9 +944,12 @@ class FlashInferMultiStepDraftBackend:
                     model_runner,
                     skip_prefill=True,
                     kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
+
         self.max_context_len = self.attn_backends[0].max_context_len
+
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

@@ -1071,13 +979,23 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )

+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+        global global_override_indptr_cpu
+
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
+            global_override_indptr_cpu = indptr_cpu_whole[i]
             call_fn(i, forward_batch)

+        global_override_indptr_cpu = None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (
@@ -1089,6 +1007,8 @@ class FlashInferMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -1105,6 +1025,7 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
+
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1138,48 +1059,12 @@ class FlashInferMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
             )

         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
-
-
 def should_use_tensor_core(
     kv_cache_dtype: torch.dtype,
     num_attention_heads: int,
@@ -1201,6 +1086,21 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"

+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads

@@ -1230,12 +1130,18 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
-    **kwargs,
+    non_blocking: bool = True,
 ) -> None:
-    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+    """
     batch_size = len(last_page_len)
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
+
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(
@@ -1248,13 +1154,19 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
+        # Skip these copies
+        # self._paged_kv_indptr_buf.copy_(indptr)
+        # self._paged_kv_indices_buf[: len(indices)] = indices
+        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
+
     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
         q_data_type = data_type
+
     if not hasattr(self, "empty_q_data"):
         self.empty_q_data = torch.empty(
             0,
             dtype=(
             ),
         )
         self.last_page_len = torch.ones(32768, dtype=torch.int32)
+
     empty_q_data = self.empty_q_data
     empty_kv_cache = self.empty_kv_cache
     stream = torch.cuda.current_stream()