sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -19,16 +19,16 @@ import triton
  import triton.language as tl

  from sglang.global_config import global_config
- from sglang.srt.layers.attention import AttentionBackend
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
  from sglang.srt.layers.dp_attention import get_attention_tp_size
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
  from sglang.srt.utils import is_flashinfer_available

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner
- from sglang.srt.speculative.spec_info import SpecInfo

  if is_flashinfer_available():
  from flashinfer import (
@@ -37,7 +37,7 @@ if is_flashinfer_available():
  BatchPrefillWithRaggedKVCacheWrapper,
  )
  from flashinfer.cascade import merge_state
- from flashinfer.mla import BatchMLAPagedAttentionWrapper
+ from flashinfer.decode import PosEncodingMode


  class WrapperDispatch(Enum):
@@ -47,16 +47,12 @@ class WrapperDispatch(Enum):

  @dataclass
  class DecodeMetadata:
- decode_wrappers: List[
- Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ]
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]


  @dataclass
  class PrefillMetadata:
- prefill_wrappers: List[
- Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ]
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
  use_ragged: bool
  extend_no_prefix: bool

@@ -73,6 +69,7 @@ class FlashInferAttnBackend(AttentionBackend):
  model_runner: ModelRunner,
  skip_prefill: bool = False,
  kv_indptr_buf: Optional[torch.Tensor] = None,
+ kv_last_page_len_buf: Optional[torch.Tensor] = None,
  ):
  super().__init__()

@@ -109,12 +106,6 @@ class FlashInferAttnBackend(AttentionBackend):
  if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
  global_config.flashinfer_workspace_size = 512 * 1024 * 1024

- self.enable_flashinfer_mla = False
- if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
- if global_server_args_dict["enable_flashinfer_mla"]:
- self.enable_flashinfer_mla = True
- global_config.enable_flashinfer_mla = True
-
  # Allocate buffers
  global global_workspace_buffer
  if global_workspace_buffer is None:
@@ -124,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
  device=model_runner.device,
  )
  self.workspace_buffer = global_workspace_buffer
+
  max_bs = model_runner.req_to_token_pool.size
  if kv_indptr_buf is None:
  self.kv_indptr = [
@@ -132,24 +124,25 @@ class FlashInferAttnBackend(AttentionBackend):
  )
  for _ in range(self.num_wrappers)
  ]
- if self.enable_flashinfer_mla:
- self.qo_indptr = [
- torch.zeros(
- (max_bs + 1,), dtype=torch.int32, device=model_runner.device
- )
- for _ in range(self.num_wrappers)
- ]
  else:
  assert self.num_wrappers == 1
  self.kv_indptr = [kv_indptr_buf]

- self.kv_last_page_len = torch.ones(
- (max_bs,), dtype=torch.int32, device=model_runner.device
- )
- self.qo_indptr = [
- torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
- for _ in range(self.num_wrappers)
- ]
+ if kv_last_page_len_buf is None:
+ self.kv_last_page_len = torch.ones(
+ (max_bs,), dtype=torch.int32, device=model_runner.device
+ )
+ else:
+ assert self.num_wrappers == 1
+ self.kv_last_page_len = kv_last_page_len_buf
+
+ if not self.skip_prefill:
+ self.qo_indptr = [
+ torch.zeros(
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+ )
+ for _ in range(self.num_wrappers)
+ ]

  self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
  self.workspace_buffer, "NHD"
@@ -162,48 +155,24 @@ class FlashInferAttnBackend(AttentionBackend):
  self.decode_wrappers = []
  for _ in range(self.num_wrappers):
  if not skip_prefill:
- if (
- self.enable_flashinfer_mla
- and not global_server_args_dict["disable_radix_cache"]
- ):
- # use mla paged prefill
- self.prefill_wrappers_paged.append(
- BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- backend="fa2",
- )
- )
- self.prefill_wrappers_verify.append(
- BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- backend="fa2",
- )
- )
- else:
- self.prefill_wrappers_paged.append(
- BatchPrefillWithPagedKVCacheWrapper(
- self.workspace_buffer,
- "NHD",
- backend="fa2",
- )
- )
- self.prefill_wrappers_verify.append(
- BatchPrefillWithPagedKVCacheWrapper(
- self.workspace_buffer, "NHD"
- )
- )
- if self.enable_flashinfer_mla:
- self.decode_wrappers.append(
- BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
- )
- else:
- self.decode_wrappers.append(
- BatchDecodeWithPagedKVCacheWrapper(
+ self.prefill_wrappers_paged.append(
+ BatchPrefillWithPagedKVCacheWrapper(
  self.workspace_buffer,
  "NHD",
- use_tensor_cores=self.decode_use_tensor_cores,
+ backend="fa2",
  )
  )
+ self.prefill_wrappers_verify.append(
+ BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+ )
+
+ self.decode_wrappers.append(
+ BatchDecodeWithPagedKVCacheWrapper(
+ self.workspace_buffer,
+ "NHD",
+ use_tensor_cores=self.decode_use_tensor_cores,
+ )
+ )

  # Create indices updater
  if not skip_prefill:
@@ -259,10 +228,7 @@ class FlashInferAttnBackend(AttentionBackend):
  else:
  prefix_lens = forward_batch.extend_prefix_lens

- if self.is_multimodal or (
- self.enable_flashinfer_mla
- and not global_server_args_dict["disable_radix_cache"]
- ):
+ if self.is_multimodal:
  use_ragged = False
  extend_no_prefix = False
  else:
@@ -316,37 +282,25 @@ class FlashInferAttnBackend(AttentionBackend):
  seq_lens: torch.Tensor,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  if forward_mode.is_decode_or_idle():
  decode_wrappers = []
  for i in range(self.num_wrappers):
- if self.enable_flashinfer_mla:
- decode_wrappers.append(
- BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- use_cuda_graph=True,
- qo_indptr=self.qo_indptr[i][: num_tokens + 1],
- kv_indptr=self.kv_indptr[i][: num_tokens + 1],
- kv_indices=self.cuda_graph_kv_indices[i],
- kv_len_arr=self.kv_last_page_len[:num_tokens],
- backend="fa2",
- )
- )
- else:
- decode_wrappers.append(
- BatchDecodeWithPagedKVCacheWrapper(
- self.workspace_buffer,
- "NHD",
- use_cuda_graph=True,
- use_tensor_cores=self.decode_use_tensor_cores,
- paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
- paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
- paged_kv_last_page_len_buffer=self.kv_last_page_len[
- :num_tokens
- ],
- )
+ decode_wrappers.append(
+ BatchDecodeWithPagedKVCacheWrapper(
+ self.workspace_buffer,
+ "NHD",
+ use_cuda_graph=True,
+ use_tensor_cores=self.decode_use_tensor_cores,
+ paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+ paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+ paged_kv_last_page_len_buffer=self.kv_last_page_len[
+ :num_tokens
+ ],
  )
+ )
+
  seq_lens_sum = seq_lens.sum().item()
  self.indices_updater_decode.update(
  req_pool_indices,
@@ -398,7 +352,8 @@ class FlashInferAttnBackend(AttentionBackend):
  seq_lens_sum: int,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ seq_lens_cpu: Optional[torch.Tensor],
  ):
  if forward_mode.is_decode_or_idle():
  self.indices_updater_decode.update(
@@ -435,114 +390,64 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
- if global_config.enable_flashinfer_mla:
- cache_loc = (
- forward_batch.out_cache_loc
- if not layer.is_cross_attention
- else forward_batch.encoder_out_cache_loc
- )
-
- logits_soft_cap = layer.logit_cap
+ prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+ self._get_wrapper_idx(layer)
+ ]
+ cache_loc = (
+ forward_batch.out_cache_loc
+ if not layer.is_cross_attention
+ else forward_batch.encoder_out_cache_loc
+ )

- if global_server_args_dict["disable_radix_cache"]:
- # use mla ragged prefill
- o, _ = self.prefill_wrapper_ragged.forward_return_lse(
- q.view(-1, layer.tp_q_head_num, layer.head_dim),
- k.view(-1, layer.tp_k_head_num, layer.head_dim),
- v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
- causal=True,
- sm_scale=layer.scaling,
- logits_soft_cap=logits_soft_cap,
- )
+ logits_soft_cap = layer.logit_cap

+ if not self.forward_metadata.use_ragged:
+ if k is not None:
+ assert v is not None
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer,
- cache_loc,
- k,
- v,
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
- else:
- # use mla paged prefill
- prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
- self._get_wrapper_idx(layer)
- ]
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v
- )
- qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
- k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
- o = prefill_wrapper_paged.run(
- qall[:, :, : layer.v_head_dim],
- qall[:, :, layer.v_head_dim :],
- k_buf[:, :, : layer.v_head_dim],
- k_buf[:, :, layer.v_head_dim :],
- )

- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ o = prefill_wrapper_paged.forward(
+ q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+ causal=not layer.is_cross_attention,
+ sm_scale=layer.scaling,
+ window_left=layer.sliding_window_size,
+ logits_soft_cap=logits_soft_cap,
+ k_scale=layer.k_scale,
+ v_scale=layer.v_scale,
+ )
  else:
- prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
- self._get_wrapper_idx(layer)
- ]
- cache_loc = (
- forward_batch.out_cache_loc
- if not layer.is_cross_attention
- else forward_batch.encoder_out_cache_loc
+ o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
+ causal=True,
+ sm_scale=layer.scaling,
+ logits_soft_cap=logits_soft_cap,
  )

- logits_soft_cap = layer.logit_cap
-
- if not self.forward_metadata.use_ragged:
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
- )
-
- o = prefill_wrapper_paged.forward(
+ if self.forward_metadata.extend_no_prefix:
+ o = o1
+ else:
+ o2, s2 = prefill_wrapper_paged.forward_return_lse(
  q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
- causal=not layer.is_cross_attention,
+ causal=False,
  sm_scale=layer.scaling,
- window_left=layer.sliding_window_size,
- logits_soft_cap=logits_soft_cap,
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
+ logits_soft_cap=layer.logit_cap,
  )
- else:
- o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
- q.view(-1, layer.tp_q_head_num, layer.head_dim),
- k.view(-1, layer.tp_k_head_num, layer.head_dim),
- v.view(-1, layer.tp_v_head_num, layer.head_dim),
- causal=True,
- sm_scale=layer.scaling,
- logits_soft_cap=logits_soft_cap,
- )
-
- if self.forward_metadata.extend_no_prefix:
- o = o1
- else:
- o2, s2 = prefill_wrapper_paged.forward_return_lse(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
- forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
- causal=False,
- sm_scale=layer.scaling,
- logits_soft_cap=layer.logit_cap,
- )

- o, _ = merge_state(o1, s1, o2, s2)
+ o, _ = merge_state(o1, s1, o2, s2)

- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
- )
+ if save_kv_cache:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ )

- return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)

  def forward_decode(
  self,
@@ -562,45 +467,23 @@ class FlashInferAttnBackend(AttentionBackend):
  else forward_batch.encoder_out_cache_loc
  )

- if self.enable_flashinfer_mla:
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer,
- cache_loc,
- k,
- v,
- )
- reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
- k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
- o = decode_wrapper.run(
- reshaped_q[:, :, : layer.v_head_dim],
- reshaped_q[:, :, layer.v_head_dim :],
- reshaped_k[:, :, : layer.v_head_dim],
- reshaped_k[:, :, layer.v_head_dim :],
- )
-
- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
- else:
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
- )
+ if k is not None:
+ assert v is not None
+ if save_kv_cache:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ )

- o = decode_wrapper.forward(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
- forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
- sm_scale=layer.scaling,
- logits_soft_cap=layer.logit_cap,
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
- )
+ o = decode_wrapper.forward(
+ q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+ sm_scale=layer.scaling,
+ logits_soft_cap=layer.logit_cap,
+ k_scale=layer.k_scale,
+ v_scale=layer.v_scale,
+ )

- return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)

  def _get_wrapper_idx(self, layer: RadixAttention):
  if self.num_wrappers == 1:
@@ -648,11 +531,9 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List[
- Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()
@@ -662,11 +543,9 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List[
- Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  decode_wrappers = decode_wrappers or self.decode_wrappers
  self.call_begin_forward(
@@ -686,7 +565,7 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -720,7 +599,7 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -745,15 +624,13 @@ class FlashInferIndicesUpdaterDecode:

  def call_begin_forward(
  self,
- wrapper: Union[
- BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
- ],
+ wrapper: BatchDecodeWithPagedKVCacheWrapper,
  req_pool_indices: torch.Tensor,
  paged_kernel_lens: torch.Tensor,
  paged_kernel_lens_sum: int,
  kv_indptr: torch.Tensor,
  kv_start_idx: torch.Tensor,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  if spec_info is None:
  bs = len(req_pool_indices)
@@ -772,40 +649,21 @@ class FlashInferIndicesUpdaterDecode:
  self.req_to_token.shape[1],
  )
  else:
+ assert isinstance(spec_info, EagleDraftInput)
  kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
  bs = kv_indptr.shape[0] - 1
-
- if global_config.enable_flashinfer_mla:
- sm_scale = 1.0 / math.sqrt(192)
- q_indptr = torch.arange(0, bs + 1).to(0).int()
- kv_lens = paged_kernel_lens.to(torch.int32)
- wrapper.plan(
- q_indptr,
- kv_indptr,
- kv_indices,
- kv_lens,
- self.num_qo_heads,
- 512,
- 64,
- 1,
- False,
- sm_scale,
- self.data_type,
- self.data_type,
- )
- else:
- wrapper.begin_forward(
- kv_indptr,
- kv_indices,
- self.kv_last_page_len[:bs],
- self.num_qo_heads,
- self.num_kv_heads,
- self.head_dim,
- 1,
- data_type=self.data_type,
- q_data_type=self.q_data_type,
- non_blocking=True,
- )
+ wrapper.begin_forward(
+ kv_indptr,
+ kv_indices,
+ self.kv_last_page_len[:bs],
+ self.num_qo_heads,
+ self.num_kv_heads,
+ self.head_dim,
+ 1,
+ data_type=self.data_type,
+ q_data_type=self.q_data_type,
+ non_blocking=True,
+ )


  class FlashInferIndicesUpdaterPrefill:
@@ -845,12 +703,10 @@ class FlashInferIndicesUpdaterPrefill:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
- prefill_wrappers: List[
- Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()
@@ -861,12 +717,10 @@ class FlashInferIndicesUpdaterPrefill:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
- prefill_wrappers: List[
- Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  if use_ragged:
  paged_kernel_lens = prefix_lens
@@ -899,7 +753,7 @@ class FlashInferIndicesUpdaterPrefill:
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -940,7 +794,7 @@ class FlashInferIndicesUpdaterPrefill:
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -972,9 +826,7 @@ class FlashInferIndicesUpdaterPrefill:
  def call_begin_forward(
  self,
  wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
- wrapper_paged: Union[
- BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
- ],
+ wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
  req_pool_indices: torch.Tensor,
  paged_kernel_lens: torch.Tensor,
  paged_kernel_lens_sum: int,
@@ -984,10 +836,11 @@ class FlashInferIndicesUpdaterPrefill:
  kv_indptr: torch.Tensor,
  qo_indptr: torch.Tensor,
  use_ragged: bool,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
- bs = len(req_pool_indices)
+ bs = len(seq_lens)
  if spec_info is None:
+ assert len(seq_lens) == len(req_pool_indices)
  # Normal extend
  kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
@@ -1010,72 +863,49 @@ class FlashInferIndicesUpdaterPrefill:
  qo_indptr = qo_indptr[: bs + 1]
  custom_mask = None
  else:
+ assert isinstance(spec_info, EagleDraftInput) or isinstance(
+ spec_info, EagleVerifyInput
+ )
  kv_indices, kv_indptr, qo_indptr, custom_mask = (
  spec_info.generate_attn_arg_prefill(
  req_pool_indices,
  paged_kernel_lens,
+ paged_kernel_lens_sum,
  self.req_to_token,
  )
  )

  # extend part
  if use_ragged:
- if global_config.enable_flashinfer_mla:
- wrapper_ragged.begin_forward(
- qo_indptr=qo_indptr,
- kv_indptr=qo_indptr,
- num_qo_heads=self.num_qo_heads,
- num_kv_heads=self.num_kv_heads,
- head_dim_qk=192,
- head_dim_vo=128,
- q_data_type=self.q_data_type,
- )
- else:
- wrapper_ragged.begin_forward(
- qo_indptr,
- qo_indptr,
- self.num_qo_heads,
- self.num_kv_heads,
- self.head_dim,
- q_data_type=self.q_data_type,
- )
-
- if not global_config.enable_flashinfer_mla:
- # cached part
- wrapper_paged.begin_forward(
+ wrapper_ragged.begin_forward(
+ qo_indptr,
  qo_indptr,
- kv_indptr,
- kv_indices,
- self.kv_last_page_len[:bs],
  self.num_qo_heads,
  self.num_kv_heads,
  self.head_dim,
- 1,
  q_data_type=self.q_data_type,
- custom_mask=custom_mask,
- non_blocking=True,
- )
- elif (
- global_config.enable_flashinfer_mla
- and not global_server_args_dict["disable_radix_cache"]
- ):
- # mla paged prefill
- kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
- wrapper_paged.plan(
- qo_indptr,
- kv_indptr,
- kv_indices,
- kv_len_arr,
- self.num_qo_heads,
- 512,
- 64,
- 1,
- True,
- 1 / math.sqrt(192),
- self.data_type,
- self.data_type,
  )

+ # cached part
+ wrapper_paged.begin_forward(
+ qo_indptr,
+ kv_indptr,
+ kv_indices,
+ self.kv_last_page_len[:bs],
+ self.num_qo_heads,
+ self.num_kv_heads,
+ self.head_dim,
+ 1,
+ q_data_type=self.q_data_type,
+ custom_mask=custom_mask,
+ non_blocking=True,
+ )
+
+
+ # Use as a fast path to override the indptr in flashinfer's plan function
+ # This is used to remove some host-to-device copy overhead.
+ global global_override_indptr_cpu
+

  class FlashInferMultiStepDraftBackend:
  """
@@ -1094,7 +924,8 @@ class FlashInferMultiStepDraftBackend:
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
  self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
- max_bs = model_runner.req_to_token_pool.size
+
+ max_bs = model_runner.req_to_token_pool.size * self.topk
  self.kv_indptr = torch.zeros(
  (
  self.speculative_num_steps,
@@ -1103,6 +934,9 @@ class FlashInferMultiStepDraftBackend:
  dtype=torch.int32,
  device=model_runner.device,
  )
+ self.kv_last_page_len = torch.ones(
+ (max_bs,), dtype=torch.int32, device=model_runner.device
+ )
  self.attn_backends = []
  for i in range(self.speculative_num_steps):
  self.attn_backends.append(
@@ -1110,9 +944,12 @@ class FlashInferMultiStepDraftBackend:
  model_runner,
  skip_prefill=True,
  kv_indptr_buf=self.kv_indptr[i],
+ kv_last_page_len_buf=self.kv_last_page_len,
  )
  )
+
  self.max_context_len = self.attn_backends[0].max_context_len
+
  # Cached variables for generate_draft_decode_kv_indices
  self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

@@ -1142,13 +979,23 @@ class FlashInferMultiStepDraftBackend:
  triton.next_power_of_2(bs),
  )

+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+ indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+ global global_override_indptr_cpu
+
  for i in range(self.speculative_num_steps - 1):
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
  forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
  : seq_lens_sum * self.topk + bs * (i + 1)
  ]
+ global_override_indptr_cpu = indptr_cpu_whole[i]
  call_fn(i, forward_batch)

+ global_override_indptr_cpu = None
+
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  kv_indices = torch.zeros(
  (
@@ -1160,6 +1007,8 @@ class FlashInferMultiStepDraftBackend:
  )

  def call_fn(i, forward_batch):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
  forward_batch.spec_info.kv_indptr = (
  forward_batch.spec_info.kv_indptr.clone()
  )
@@ -1176,6 +1025,7 @@ class FlashInferMultiStepDraftBackend:
  dtype=torch.int32,
  device="cuda",
  )
+
  for i in range(self.speculative_num_steps):
  self.attn_backends[i].init_cuda_graph_state(
  max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1209,48 +1059,12 @@ class FlashInferMultiStepDraftBackend:
  encoder_lens=None,
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
+ seq_lens_cpu=None,
  )

  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


- @triton.jit
- def create_flashinfer_kv_indices_triton(
- req_to_token_ptr, # [max_batch, max_context_len]
- req_pool_indices_ptr,
- page_kernel_lens_ptr,
- kv_indptr,
- kv_start_idx,
- kv_indices_ptr,
- req_to_token_ptr_stride: tl.constexpr,
- ):
- BLOCK_SIZE: tl.constexpr = 512
- pid = tl.program_id(axis=0)
-
- req_pool_index = tl.load(req_pool_indices_ptr + pid)
- kv_indices_offset = tl.load(kv_indptr + pid)
-
- kv_start = 0
- kv_end = 0
- if kv_start_idx:
- kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
- kv_end = kv_start
- kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
- num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
- for i in range(num_loop):
- offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
- mask = offset < kv_end - kv_start
- data = tl.load(
- req_to_token_ptr
- + req_pool_index * req_to_token_ptr_stride
- + kv_start
- + offset,
- mask=mask,
- )
- tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
-
-
  def should_use_tensor_core(
  kv_cache_dtype: torch.dtype,
  num_attention_heads: int,
@@ -1272,6 +1086,21 @@ def should_use_tensor_core(
  if env_override is not None:
  return env_override.lower() == "true"

+ # Try to use _grouped_size_compiled_for_decode_kernels if available
+ # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+ try:
+ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+ if not _grouped_size_compiled_for_decode_kernels(
+ num_attention_heads,
+ num_kv_heads,
+ ):
+ return True
+ else:
+ return False
+ except (ImportError, AttributeError):
+ pass
+
  # Calculate GQA group size
  gqa_group_size = num_attention_heads // num_kv_heads

@@ -1301,12 +1130,18 @@ def fast_decode_plan(
  sm_scale: Optional[float] = None,
  rope_scale: Optional[float] = None,
  rope_theta: Optional[float] = None,
- **kwargs,
+ non_blocking: bool = True,
  ) -> None:
- """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+ """
+ A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+ Modifications:
+ - Remove unnecessary device-to-device copy for the cuda graph buffers.
+ - Remove unnecessary host-to-device copy for the metadata buffers.
+ """
  batch_size = len(last_page_len)
  if logits_soft_cap is None:
  logits_soft_cap = 0.0
+
  if self.is_cuda_graph_enabled:
  if batch_size != self._fixed_batch_size:
  raise ValueError(
@@ -1319,13 +1154,19 @@ def fast_decode_plan(
  raise ValueError(
  "The size of indices should be less than or equal to the allocated buffer"
  )
+ # Skip these copies
+ # self._paged_kv_indptr_buf.copy_(indptr)
+ # self._paged_kv_indices_buf[: len(indices)] = indices
+ # self._paged_kv_last_page_len_buf.copy_(last_page_len)
  else:
  self._paged_kv_indptr_buf = indptr
  self._paged_kv_indices_buf = indices
  self._paged_kv_last_page_len_buf = last_page_len
+
  # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
  if not q_data_type:
  q_data_type = data_type
+
  if not hasattr(self, "empty_q_data"):
  self.empty_q_data = torch.empty(
  0,
@@ -1342,6 +1183,7 @@ def fast_decode_plan(
  ),
  )
  self.last_page_len = torch.ones(32768, dtype=torch.int32)
+
  empty_q_data = self.empty_q_data
  empty_kv_cache = self.empty_kv_cache
  stream = torch.cuda.current_stream()