sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
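Orientation (editor's note, not part of the registry text): the version bump itself lands in sglang/version.py, so after installing the newer wheel you can confirm which side of this diff you are running. A minimal check, assuming the package exposes __version__ from sglang/version.py as usual:

    import sglang
    print(sglang.__version__)  # "0.4.3.post4" after the upgrade, "0.4.3.post2" before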
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
@@ -7,28 +7,26 @@ FlashInfer is faster and Triton is easier to customize.
  Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
  """

- import math
  import os
  from dataclasses import dataclass
  from enum import Enum, auto
  from functools import partial
- from typing import TYPE_CHECKING, List, Optional, Union
+ from typing import TYPE_CHECKING, Callable, List, Optional, Union

  import torch
  import triton
- import triton.language as tl

  from sglang.global_config import global_config
- from sglang.srt.layers.attention import AttentionBackend
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
  from sglang.srt.layers.dp_attention import get_attention_tp_size
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
  from sglang.srt.utils import is_flashinfer_available

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner
- from sglang.srt.speculative.spec_info import SpecInfo

  if is_flashinfer_available():
  from flashinfer import (
@@ -37,7 +35,7 @@ if is_flashinfer_available():
  BatchPrefillWithRaggedKVCacheWrapper,
  )
  from flashinfer.cascade import merge_state
- from flashinfer.mla import BatchMLAPagedAttentionWrapper
+ from flashinfer.decode import _get_range_buf, get_seq_lens


  class WrapperDispatch(Enum):
@@ -47,16 +45,12 @@ class WrapperDispatch(Enum):

  @dataclass
  class DecodeMetadata:
- decode_wrappers: List[
- Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ]
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]


  @dataclass
  class PrefillMetadata:
- prefill_wrappers: List[
- Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ]
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
  use_ragged: bool
  extend_no_prefix: bool

@@ -73,11 +67,10 @@ class FlashInferAttnBackend(AttentionBackend):
  model_runner: ModelRunner,
  skip_prefill: bool = False,
  kv_indptr_buf: Optional[torch.Tensor] = None,
+ kv_last_page_len_buf: Optional[torch.Tensor] = None,
  ):
  super().__init__()

- self.is_multimodal = model_runner.model_config.is_multimodal
-
  # Parse constants
  self.decode_use_tensor_cores = should_use_tensor_core(
  kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -89,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
  )
  self.max_context_len = model_runner.model_config.context_len
  self.skip_prefill = skip_prefill
+ self.is_multimodal = model_runner.model_config.is_multimodal

  assert not (
  model_runner.sliding_window_size is not None
@@ -109,12 +103,6 @@ class FlashInferAttnBackend(AttentionBackend):
  if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
  global_config.flashinfer_workspace_size = 512 * 1024 * 1024

- self.enable_flashinfer_mla = False
- if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
- if global_server_args_dict["enable_flashinfer_mla"]:
- self.enable_flashinfer_mla = True
- global_config.enable_flashinfer_mla = True
-
  # Allocate buffers
  global global_workspace_buffer
  if global_workspace_buffer is None:
@@ -132,24 +120,25 @@ class FlashInferAttnBackend(AttentionBackend):
  )
  for _ in range(self.num_wrappers)
  ]
- if self.enable_flashinfer_mla:
- self.qo_indptr = [
- torch.zeros(
- (max_bs + 1,), dtype=torch.int32, device=model_runner.device
- )
- for _ in range(self.num_wrappers)
- ]
  else:
  assert self.num_wrappers == 1
  self.kv_indptr = [kv_indptr_buf]

- self.kv_last_page_len = torch.ones(
- (max_bs,), dtype=torch.int32, device=model_runner.device
- )
- self.qo_indptr = [
- torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
- for _ in range(self.num_wrappers)
- ]
+ if kv_last_page_len_buf is None:
+ self.kv_last_page_len = torch.ones(
+ (max_bs,), dtype=torch.int32, device=model_runner.device
+ )
+ else:
+ assert self.num_wrappers == 1
+ self.kv_last_page_len = kv_last_page_len_buf
+
+ if not self.skip_prefill:
+ self.qo_indptr = [
+ torch.zeros(
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+ )
+ for _ in range(self.num_wrappers)
+ ]

  self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
  self.workspace_buffer, "NHD"
@@ -162,60 +151,39 @@ class FlashInferAttnBackend(AttentionBackend):
  self.decode_wrappers = []
  for _ in range(self.num_wrappers):
  if not skip_prefill:
- if (
- self.enable_flashinfer_mla
- and not global_server_args_dict["disable_radix_cache"]
- ):
- # use mla paged prefill
- self.prefill_wrappers_paged.append(
- BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- backend="fa2",
- )
- )
- self.prefill_wrappers_verify.append(
- BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- backend="fa2",
- )
- )
- else:
- self.prefill_wrappers_paged.append(
- BatchPrefillWithPagedKVCacheWrapper(
- self.workspace_buffer,
- "NHD",
- backend="fa2",
- )
- )
- self.prefill_wrappers_verify.append(
- BatchPrefillWithPagedKVCacheWrapper(
- self.workspace_buffer, "NHD"
- )
+ self.prefill_wrappers_paged.append(
+ BatchPrefillWithPagedKVCacheWrapper(
+ self.workspace_buffer,
+ "NHD",
+ backend="fa2",
  )
- if self.enable_flashinfer_mla:
- self.decode_wrappers.append(
- BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
  )
- else:
- self.decode_wrappers.append(
- BatchDecodeWithPagedKVCacheWrapper(
+ self.prefill_wrappers_verify.append(
+ BatchPrefillWithPagedKVCacheWrapper(
  self.workspace_buffer,
  "NHD",
- use_tensor_cores=self.decode_use_tensor_cores,
  )
  )
+ self.decode_wrappers.append(
+ BatchDecodeWithPagedKVCacheWrapper(
+ self.workspace_buffer,
+ "NHD",
+ use_tensor_cores=self.decode_use_tensor_cores,
+ )
+ )

  # Create indices updater
  if not skip_prefill:
  self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
  model_runner, self
- )
+ ) # for verify
  self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)

  # Other metadata
  self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
  self.decode_cuda_graph_metadata = {}
- self.prefill_cuda_graph_metadata = {}
+ self.prefill_cuda_graph_metadata = {} # For verify
+ self.draft_extend_cuda_graph_metadata = {} # For draft extend

  def init_forward_metadata(self, forward_batch: ForwardBatch):
  if forward_batch.forward_mode.is_decode_or_idle():
@@ -259,10 +227,7 @@ class FlashInferAttnBackend(AttentionBackend):
  else:
  prefix_lens = forward_batch.extend_prefix_lens

- if self.is_multimodal or (
- self.enable_flashinfer_mla
- and not global_server_args_dict["disable_radix_cache"]
- ):
+ if self.is_multimodal:
  use_ragged = False
  extend_no_prefix = False
  else:
@@ -316,37 +281,24 @@ class FlashInferAttnBackend(AttentionBackend):
  seq_lens: torch.Tensor,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  if forward_mode.is_decode_or_idle():
  decode_wrappers = []
  for i in range(self.num_wrappers):
- if self.enable_flashinfer_mla:
- decode_wrappers.append(
- BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- use_cuda_graph=True,
- qo_indptr=self.qo_indptr[i][: num_tokens + 1],
- kv_indptr=self.kv_indptr[i][: num_tokens + 1],
- kv_indices=self.cuda_graph_kv_indices[i],
- kv_len_arr=self.kv_last_page_len[:num_tokens],
- backend="fa2",
- )
- )
- else:
- decode_wrappers.append(
- BatchDecodeWithPagedKVCacheWrapper(
- self.workspace_buffer,
- "NHD",
- use_cuda_graph=True,
- use_tensor_cores=self.decode_use_tensor_cores,
- paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
- paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
- paged_kv_last_page_len_buffer=self.kv_last_page_len[
- :num_tokens
- ],
- )
+ decode_wrappers.append(
+ BatchDecodeWithPagedKVCacheWrapper(
+ self.workspace_buffer,
+ "NHD",
+ use_cuda_graph=True,
+ use_tensor_cores=self.decode_use_tensor_cores,
+ paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+ paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+ paged_kv_last_page_len_buffer=self.kv_last_page_len[
+ :num_tokens
+ ],
  )
+ )
  seq_lens_sum = seq_lens.sum().item()
  self.indices_updater_decode.update(
  req_pool_indices,
@@ -358,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
  )
  self.decode_cuda_graph_metadata[bs] = decode_wrappers
  self.forward_metadata = DecodeMetadata(decode_wrappers)
+ for i in range(self.num_wrappers):
+ decode_wrappers[i].begin_forward = partial(
+ fast_decode_plan, decode_wrappers[i]
+ )
  elif forward_mode.is_target_verify():
  prefill_wrappers = []
  for i in range(self.num_wrappers):
@@ -398,7 +354,8 @@ class FlashInferAttnBackend(AttentionBackend):
  seq_lens_sum: int,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ seq_lens_cpu: Optional[torch.Tensor],
  ):
  if forward_mode.is_decode_or_idle():
  self.indices_updater_decode.update(
@@ -435,114 +392,64 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
- if global_config.enable_flashinfer_mla:
- cache_loc = (
- forward_batch.out_cache_loc
- if not layer.is_cross_attention
- else forward_batch.encoder_out_cache_loc
- )
-
- logits_soft_cap = layer.logit_cap
+ prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+ self._get_wrapper_idx(layer)
+ ]
+ cache_loc = (
+ forward_batch.out_cache_loc
+ if not layer.is_cross_attention
+ else forward_batch.encoder_out_cache_loc
+ )

- if global_server_args_dict["disable_radix_cache"]:
- # use mla ragged prefill
- o, _ = self.prefill_wrapper_ragged.forward_return_lse(
- q.view(-1, layer.tp_q_head_num, layer.head_dim),
- k.view(-1, layer.tp_k_head_num, layer.head_dim),
- v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
- causal=True,
- sm_scale=layer.scaling,
- logits_soft_cap=logits_soft_cap,
- )
+ logits_soft_cap = layer.logit_cap

+ if not self.forward_metadata.use_ragged:
+ if k is not None:
+ assert v is not None
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer,
- cache_loc,
- k,
- v,
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
- else:
- # use mla paged prefill
- prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
- self._get_wrapper_idx(layer)
- ]
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v
- )
- qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
- k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
- o = prefill_wrapper_paged.run(
- qall[:, :, : layer.v_head_dim],
- qall[:, :, layer.v_head_dim :],
- k_buf[:, :, : layer.v_head_dim],
- k_buf[:, :, layer.v_head_dim :],
- )

- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ o = prefill_wrapper_paged.forward(
+ q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+ causal=not layer.is_cross_attention,
+ sm_scale=layer.scaling,
+ window_left=layer.sliding_window_size,
+ logits_soft_cap=logits_soft_cap,
+ k_scale=layer.k_scale,
+ v_scale=layer.v_scale,
+ )
  else:
- prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
- self._get_wrapper_idx(layer)
- ]
- cache_loc = (
- forward_batch.out_cache_loc
- if not layer.is_cross_attention
- else forward_batch.encoder_out_cache_loc
+ o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
+ causal=True,
+ sm_scale=layer.scaling,
+ logits_soft_cap=logits_soft_cap,
  )

- logits_soft_cap = layer.logit_cap
-
- if not self.forward_metadata.use_ragged:
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
- )
-
- o = prefill_wrapper_paged.forward(
+ if self.forward_metadata.extend_no_prefix:
+ o = o1
+ else:
+ o2, s2 = prefill_wrapper_paged.forward_return_lse(
  q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
- causal=not layer.is_cross_attention,
- sm_scale=layer.scaling,
- window_left=layer.sliding_window_size,
- logits_soft_cap=logits_soft_cap,
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
- )
- else:
- o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
- q.view(-1, layer.tp_q_head_num, layer.head_dim),
- k.view(-1, layer.tp_k_head_num, layer.head_dim),
- v.view(-1, layer.tp_v_head_num, layer.head_dim),
- causal=True,
+ causal=False,
  sm_scale=layer.scaling,
  logits_soft_cap=logits_soft_cap,
  )

- if self.forward_metadata.extend_no_prefix:
- o = o1
- else:
- o2, s2 = prefill_wrapper_paged.forward_return_lse(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
- forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
- causal=False,
- sm_scale=layer.scaling,
- logits_soft_cap=layer.logit_cap,
- )
-
- o, _ = merge_state(o1, s1, o2, s2)
+ o, _ = merge_state(o1, s1, o2, s2)

- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
- )
+ if save_kv_cache:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ )

- return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)

  def forward_decode(
  self,
@@ -562,45 +469,23 @@ class FlashInferAttnBackend(AttentionBackend):
  else forward_batch.encoder_out_cache_loc
  )

- if self.enable_flashinfer_mla:
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer,
- cache_loc,
- k,
- v,
- )
- reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
- k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
- o = decode_wrapper.run(
- reshaped_q[:, :, : layer.v_head_dim],
- reshaped_q[:, :, layer.v_head_dim :],
- reshaped_k[:, :, : layer.v_head_dim],
- reshaped_k[:, :, layer.v_head_dim :],
- )
-
- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
- else:
- if k is not None:
- assert v is not None
- if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
- )
+ if k is not None:
+ assert v is not None
+ if save_kv_cache:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ )

- o = decode_wrapper.forward(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
- forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
- sm_scale=layer.scaling,
- logits_soft_cap=layer.logit_cap,
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
- )
+ o = decode_wrapper.forward(
+ q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+ sm_scale=layer.scaling,
+ logits_soft_cap=layer.logit_cap,
+ k_scale=layer.k_scale,
+ v_scale=layer.v_scale,
+ )

- return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)

  def _get_wrapper_idx(self, layer: RadixAttention):
  if self.num_wrappers == 1:
@@ -648,11 +533,9 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List[
- Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()
@@ -662,11 +545,9 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List[
- Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  decode_wrappers = decode_wrappers or self.decode_wrappers
  self.call_begin_forward(
@@ -686,7 +567,7 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -720,7 +601,7 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -745,23 +626,27 @@ class FlashInferIndicesUpdaterDecode:

  def call_begin_forward(
  self,
- wrapper: Union[
- BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
- ],
+ wrapper: BatchDecodeWithPagedKVCacheWrapper,
  req_pool_indices: torch.Tensor,
  paged_kernel_lens: torch.Tensor,
  paged_kernel_lens_sum: int,
  kv_indptr: torch.Tensor,
  kv_start_idx: torch.Tensor,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  if spec_info is None:
  bs = len(req_pool_indices)
  kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indices = torch.empty(
- paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
- )
+
+ if wrapper.is_cuda_graph_enabled:
+ # Directly write to the cuda graph input buffer
+ kv_indices = wrapper._paged_kv_indices_buf
+ else:
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+ )
+
  create_flashinfer_kv_indices_triton[(bs,)](
  self.req_to_token,
  req_pool_indices,
@@ -775,37 +660,18 @@ class FlashInferIndicesUpdaterDecode:
  kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
  bs = kv_indptr.shape[0] - 1

- if global_config.enable_flashinfer_mla:
- sm_scale = 1.0 / math.sqrt(192)
- q_indptr = torch.arange(0, bs + 1).to(0).int()
- kv_lens = paged_kernel_lens.to(torch.int32)
- wrapper.plan(
- q_indptr,
- kv_indptr,
- kv_indices,
- kv_lens,
- self.num_qo_heads,
- 512,
- 64,
- 1,
- False,
- sm_scale,
- self.data_type,
- self.data_type,
- )
- else:
- wrapper.begin_forward(
- kv_indptr,
- kv_indices,
- self.kv_last_page_len[:bs],
- self.num_qo_heads,
- self.num_kv_heads,
- self.head_dim,
- 1,
- data_type=self.data_type,
- q_data_type=self.q_data_type,
- non_blocking=True,
- )
+ wrapper.begin_forward(
+ kv_indptr,
+ kv_indices,
+ self.kv_last_page_len[:bs],
+ self.num_qo_heads,
+ self.num_kv_heads,
+ self.head_dim,
+ 1,
+ data_type=self.data_type,
+ q_data_type=self.q_data_type,
+ non_blocking=True,
+ )


  class FlashInferIndicesUpdaterPrefill:
@@ -841,32 +707,28 @@ class FlashInferIndicesUpdaterPrefill:

  def update(
  self,
- req_pool_indices: torch.Tnesor,
+ req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
- prefill_wrappers: List[
- Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()

  def update_single_wrapper(
  self,
- req_pool_indices: torch.Tnesor,
+ req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
- prefill_wrappers: List[
- Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
- ],
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  if use_ragged:
  paged_kernel_lens = prefix_lens
@@ -899,7 +761,7 @@ class FlashInferIndicesUpdaterPrefill:
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -940,7 +802,7 @@ class FlashInferIndicesUpdaterPrefill:
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -972,9 +834,7 @@ class FlashInferIndicesUpdaterPrefill:
  def call_begin_forward(
  self,
  wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
- wrapper_paged: Union[
- BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
- ],
+ wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
  req_pool_indices: torch.Tensor,
  paged_kernel_lens: torch.Tensor,
  paged_kernel_lens_sum: int,
@@ -984,10 +844,11 @@ class FlashInferIndicesUpdaterPrefill:
  kv_indptr: torch.Tensor,
  qo_indptr: torch.Tensor,
  use_ragged: bool,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
- bs = len(req_pool_indices)
+ bs = len(seq_lens)
  if spec_info is None:
+ assert len(seq_lens) == len(req_pool_indices)
  # Normal extend
  kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
@@ -1005,77 +866,54 @@ class FlashInferIndicesUpdaterPrefill:
  kv_indices,
  self.req_to_token.shape[1],
  )
-
  qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
  qo_indptr = qo_indptr[: bs + 1]
  custom_mask = None
  else:
+ assert isinstance(spec_info, EagleDraftInput) or isinstance(
+ spec_info, EagleVerifyInput
+ )
  kv_indices, kv_indptr, qo_indptr, custom_mask = (
  spec_info.generate_attn_arg_prefill(
  req_pool_indices,
  paged_kernel_lens,
+ paged_kernel_lens_sum,
  self.req_to_token,
  )
  )

  # extend part
  if use_ragged:
- if global_config.enable_flashinfer_mla:
- wrapper_ragged.begin_forward(
- qo_indptr=qo_indptr,
- kv_indptr=qo_indptr,
- num_qo_heads=self.num_qo_heads,
- num_kv_heads=self.num_kv_heads,
- head_dim_qk=192,
- head_dim_vo=128,
- q_data_type=self.q_data_type,
- )
- else:
- wrapper_ragged.begin_forward(
- qo_indptr,
- qo_indptr,
- self.num_qo_heads,
- self.num_kv_heads,
- self.head_dim,
- q_data_type=self.q_data_type,
- )
-
- if not global_config.enable_flashinfer_mla:
- # cached part
- wrapper_paged.begin_forward(
+ wrapper_ragged.begin_forward(
+ qo_indptr,
  qo_indptr,
- kv_indptr,
- kv_indices,
- self.kv_last_page_len[:bs],
  self.num_qo_heads,
  self.num_kv_heads,
  self.head_dim,
- 1,
  q_data_type=self.q_data_type,
- custom_mask=custom_mask,
- non_blocking=True,
- )
- elif (
- global_config.enable_flashinfer_mla
- and not global_server_args_dict["disable_radix_cache"]
- ):
- # mla paged prefill
- kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
- wrapper_paged.plan(
- qo_indptr,
- kv_indptr,
- kv_indices,
- kv_len_arr,
- self.num_qo_heads,
- 512,
- 64,
- 1,
- True,
- 1 / math.sqrt(192),
- self.data_type,
- self.data_type,
  )

+ # cached part
+ wrapper_paged.begin_forward(
+ qo_indptr,
+ kv_indptr,
+ kv_indices,
+ self.kv_last_page_len[:bs],
+ self.num_qo_heads,
+ self.num_kv_heads,
+ self.head_dim,
+ 1,
+ q_data_type=self.q_data_type,
+ kv_data_type=self.data_type,
+ custom_mask=custom_mask,
+ non_blocking=True,
+ )
+
+
+ # Use as a fast path to override the indptr in flashinfer's plan function
+ # This is used to remove some host-to-device copy overhead.
+ global global_override_indptr_cpu
+

  class FlashInferMultiStepDraftBackend:
  """
@@ -1094,7 +932,8 @@ class FlashInferMultiStepDraftBackend:
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
  self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
- max_bs = model_runner.req_to_token_pool.size
+
+ max_bs = model_runner.req_to_token_pool.size * self.topk
  self.kv_indptr = torch.zeros(
  (
  self.speculative_num_steps,
@@ -1103,6 +942,9 @@ class FlashInferMultiStepDraftBackend:
  dtype=torch.int32,
  device=model_runner.device,
  )
+ self.kv_last_page_len = torch.ones(
+ (max_bs,), dtype=torch.int32, device=model_runner.device
+ )
  self.attn_backends = []
  for i in range(self.speculative_num_steps):
  self.attn_backends.append(
@@ -1110,14 +952,20 @@ class FlashInferMultiStepDraftBackend:
  model_runner,
  skip_prefill=True,
  kv_indptr_buf=self.kv_indptr[i],
+ kv_last_page_len_buf=self.kv_last_page_len,
  )
  )
+
  self.max_context_len = self.attn_backends[0].max_context_len
+
  # Cached variables for generate_draft_decode_kv_indices
  self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

  def common_template(
- self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+ self,
+ forward_batch: ForwardBatch,
+ kv_indices_buffer: torch.Tensor,
+ call_fn: Callable,
  ):
  num_seqs = forward_batch.batch_size
  bs = self.topk * num_seqs
@@ -1142,13 +990,23 @@ class FlashInferMultiStepDraftBackend:
  triton.next_power_of_2(bs),
  )

+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+ indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+ global global_override_indptr_cpu
+
  for i in range(self.speculative_num_steps - 1):
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
  forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
  : seq_lens_sum * self.topk + bs * (i + 1)
  ]
+ global_override_indptr_cpu = indptr_cpu_whole[i]
  call_fn(i, forward_batch)

+ global_override_indptr_cpu = None
+
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  kv_indices = torch.zeros(
  (
@@ -1160,6 +1018,8 @@ class FlashInferMultiStepDraftBackend:
  )

  def call_fn(i, forward_batch):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
  forward_batch.spec_info.kv_indptr = (
  forward_batch.spec_info.kv_indptr.clone()
  )
@@ -1176,6 +1036,7 @@ class FlashInferMultiStepDraftBackend:
  dtype=torch.int32,
  device="cuda",
  )
+
  for i in range(self.speculative_num_steps):
  self.attn_backends[i].init_cuda_graph_state(
  max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1192,65 +1053,27 @@ class FlashInferMultiStepDraftBackend:
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
  )
- decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
- forward_batch.batch_size
- ][0]
- decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

- def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+ def init_forward_metadata_replay_cuda_graph(
+ self, forward_batch: ForwardBatch, bs: int
+ ):
  def call_fn(i, forward_batch):
  self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
- forward_batch.batch_size,
+ bs,
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
  seq_lens_sum=-1,
  encoder_lens=None,
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
+ seq_lens_cpu=None,
  )

  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


- @triton.jit
- def create_flashinfer_kv_indices_triton(
- req_to_token_ptr, # [max_batch, max_context_len]
- req_pool_indices_ptr,
- page_kernel_lens_ptr,
- kv_indptr,
- kv_start_idx,
- kv_indices_ptr,
- req_to_token_ptr_stride: tl.constexpr,
- ):
- BLOCK_SIZE: tl.constexpr = 512
- pid = tl.program_id(axis=0)
-
- req_pool_index = tl.load(req_pool_indices_ptr + pid)
- kv_indices_offset = tl.load(kv_indptr + pid)
-
- kv_start = 0
- kv_end = 0
- if kv_start_idx:
- kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
- kv_end = kv_start
- kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
- num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
- for i in range(num_loop):
- offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
- mask = offset < kv_end - kv_start
- data = tl.load(
- req_to_token_ptr
- + req_pool_index * req_to_token_ptr_stride
- + kv_start
- + offset,
- mask=mask,
- )
- tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
-
-
  def should_use_tensor_core(
  kv_cache_dtype: torch.dtype,
  num_attention_heads: int,
@@ -1272,6 +1095,21 @@ def should_use_tensor_core(
  if env_override is not None:
  return env_override.lower() == "true"

+ # Try to use _grouped_size_compiled_for_decode_kernels if available
+ # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+ try:
+ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+ if not _grouped_size_compiled_for_decode_kernels(
+ num_attention_heads,
+ num_kv_heads,
+ ):
+ return True
+ else:
+ return False
+ except (ImportError, AttributeError):
+ pass
+
  # Calculate GQA group size
  gqa_group_size = num_attention_heads // num_kv_heads

@@ -1284,6 +1122,11 @@ def should_use_tensor_core(
  return False


+ # Use as a fast path to override the indptr in flashinfer's plan function
+ # This is used to remove some host-to-device copy overhead.
+ global_override_indptr_cpu = None
+
+
  def fast_decode_plan(
  self,
  indptr: torch.Tensor,
@@ -1301,12 +1144,21 @@ def fast_decode_plan(
  sm_scale: Optional[float] = None,
  rope_scale: Optional[float] = None,
  rope_theta: Optional[float] = None,
- **kwargs,
+ non_blocking: bool = True,
  ) -> None:
- """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+ """
+ A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+ Modifications:
+ - Remove unnecessary device-to-device copy for the cuda graph buffers.
+ - Remove unnecessary host-to-device copy for the metadata buffers.
+ """
  batch_size = len(last_page_len)
  if logits_soft_cap is None:
  logits_soft_cap = 0.0
+
+ if self.use_tensor_cores:
+ qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
  if self.is_cuda_graph_enabled:
  if batch_size != self._fixed_batch_size:
  raise ValueError(
@@ -1319,13 +1171,20 @@ def fast_decode_plan(
  raise ValueError(
  "The size of indices should be less than or equal to the allocated buffer"
  )
+ # Skip these copies because we directly write to them during prepartion
+ # self._paged_kv_indptr_buf.copy_(indptr)
+ # self._paged_kv_indices_buf[: len(indices)] = indices
+ # self._paged_kv_last_page_len_buf.copy_(last_page_len)
  else:
  self._paged_kv_indptr_buf = indptr
  self._paged_kv_indices_buf = indices
  self._paged_kv_last_page_len_buf = last_page_len
+ self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
+
  # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
  if not q_data_type:
  q_data_type = data_type
+
  if not hasattr(self, "empty_q_data"):
  self.empty_q_data = torch.empty(
  0,
@@ -1342,27 +1201,56 @@ def fast_decode_plan(
  ),
  )
  self.last_page_len = torch.ones(32768, dtype=torch.int32)
- empty_q_data = self.empty_q_data
- empty_kv_cache = self.empty_kv_cache
- stream = torch.cuda.current_stream()
- self._cached_module.plan(
- self._float_workspace_buffer,
- self._int_workspace_buffer,
- self._pin_memory_int_workspace_buffer,
- indptr.to("cpu"),
- batch_size,
- num_qo_heads,
- num_kv_heads,
- page_size,
- self.is_cuda_graph_enabled,
- window_left,
- logits_soft_cap,
- head_dim,
- head_dim,
- empty_q_data,
- empty_kv_cache,
- stream.cuda_stream,
+
+ indptr_host = (
+ global_override_indptr_cpu
+ if global_override_indptr_cpu is not None
+ else indptr.cpu()
  )
+
+ if self.use_tensor_cores:
+ kv_lens_arr_host = get_seq_lens(
+ indptr_host, self.last_page_len[:batch_size], page_size
+ )
+
+ self._plan_info = self._cached_module.plan(
+ self._float_workspace_buffer,
+ self._int_workspace_buffer,
+ self._pin_memory_int_workspace_buffer,
+ qo_indptr_host,
+ indptr_host,
+ kv_lens_arr_host,
+ batch_size, # total_num_rows
+ batch_size,
+ num_qo_heads,
+ num_kv_heads,
+ page_size,
+ self.is_cuda_graph_enabled,
+ head_dim,
+ head_dim,
+ False, # causal
+ torch.cuda.current_stream().cuda_stream,
+ )
+ else:
+ self._plan_info = self._cached_module.plan(
+ self._float_workspace_buffer,
+ self._int_workspace_buffer,
+ self._pin_memory_int_workspace_buffer,
+ indptr_host,
+ batch_size,
+ num_qo_heads,
+ num_kv_heads,
+ page_size,
+ self.is_cuda_graph_enabled,
+ window_left,
+ logits_soft_cap,
+ head_dim,
+ head_dim,
+ self.empty_q_data,
+ self.empty_kv_cache,
+ torch.cuda.current_stream().cuda_stream,
+ )
+
  self._pos_encoding_mode = pos_encoding_mode
  self._window_left = window_left
  self._logits_soft_cap = logits_soft_cap
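
Editor's note on the decode fast path shown in this diff: the backend rebinds each CUDA-graph decode wrapper's begin_forward to fast_decode_plan via functools.partial, and FlashInferMultiStepDraftBackend stages a cached host-side indptr through the module-level global_override_indptr_cpu so the plan step can skip a host/device copy on every draft step. The snippet below is a minimal, self-contained sketch of that rebinding pattern only; the class and function names are illustrative and it does not call the real flashinfer API.

    from functools import partial

    # Module-level override, mirroring global_override_indptr_cpu in the diff.
    override_indptr_cpu = None

    class DecodeWrapper:
        def begin_forward(self, indptr):
            # Default path: imagine a device-to-host copy of indptr on every call.
            print("default plan with", indptr)

    def fast_plan(self, indptr):
        # Fast path: reuse the cached host-side indptr when the override is set.
        host_indptr = override_indptr_cpu if override_indptr_cpu is not None else indptr
        print("fast plan with", host_indptr)

    wrapper = DecodeWrapper()
    # Rebind on the instance, as the diff does for each decode wrapper in the CUDA-graph path.
    wrapper.begin_forward = partial(fast_plan, wrapper)

    override_indptr_cpu = [0, 4, 9]   # set once per draft step in common_template
    wrapper.begin_forward([0, 4, 9])  # now dispatches to fast_plan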