sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -24,9 +24,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -61,6 +59,115 @@ class PrefillMetadata:
 global_workspace_buffer = None
 
 
+class FlashInferMhaChunkKVRunner:
+    def __init__(
+        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+    ):
+        # Parse Constants
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.data_type = model_runner.dtype
+        self.q_data_type = model_runner.dtype
+
+        # Buffers and wrappers
+        self.qo_indptr = attn_backend.qo_indptr
+        self.workspace_buffer = attn_backend.workspace_buffer
+        self.fmha_backend = attn_backend.fmha_backend
+
+        self.chunk_ragged_wrappers = []
+        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
+
+    def update_prefix_chunks(self, num_prefix_chunks: int):
+        while num_prefix_chunks > len(self.chunk_ragged_wrappers):
+            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
+                self.workspace_buffer, "NHD", backend=self.fmha_backend
+            )
+            self.chunk_ragged_wrappers.append(ragged_wrapper)
+
+    def update_wrapper(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.num_prefix_chunks is not None
+        num_prefix_chunks = forward_batch.num_prefix_chunks
+        self.update_prefix_chunks(num_prefix_chunks)
+
+        prefix_lens = forward_batch.extend_prefix_lens
+        seq_lens = forward_batch.seq_lens
+
+        bs = len(seq_lens)
+        qo_indptr = self.qo_indptr
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+        qo_indptr = qo_indptr[: bs + 1]
+
+        for chunk_idx in range(forward_batch.num_prefix_chunks):
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=False,
+            )
+        # ragged prefill
+        self.ragged_wrapper.begin_forward(
+            qo_indptr=qo_indptr,
+            kv_indptr=qo_indptr,
+            num_qo_heads=self.num_local_heads,
+            num_kv_heads=self.num_local_heads,
+            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+            head_dim_vo=self.v_head_dim,
+            q_data_type=self.q_data_type,
+            causal=True,
+        )
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
+        logits_soft_cap = layer.logit_cap
+        if forward_batch.attn_attend_prefix_cache:
+            chunk_idx = forward_batch.prefix_chunk_idx
+            assert chunk_idx >= 0
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            o1, s1 = wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=False,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+        else:
+            o1, s1 = self.ragged_wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+        return o1, s1
+
+
 class FlashInferMLAAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
@@ -72,15 +179,22 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
-
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.enable_chunk_kv = (
+            not skip_prefill
+            and global_server_args_dict["disaggregation_mode"] != "decode"
+            and not global_server_args_dict["disable_chunked_prefix_cache"]
+            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+        )
+        self.page_size = model_runner.page_size
 
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
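
Note (illustration, not part of the diff): the FlashInferMhaChunkKVRunner added above returns a per-chunk (output, LSE) pair from forward_return_lse instead of a finished attention output, so the caller can attend to each prefix KV chunk and to the new tokens separately and combine the partial results afterwards. A minimal PyTorch sketch of that combination step, assuming natural-log LSE values and a [tokens, heads, head_dim] / [tokens, heads] layout (names are hypothetical, not sglang APIs):

import torch

def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # Softmax-weighted merge of two attention outputs computed over disjoint
    # KV sets, using the standard numerically stable log-sum-exp recombination.
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    o = o_a * (w_a / denom).unsqueeze(-1) + o_b * (w_b / denom).unsqueeze(-1)
    lse = lse_max + torch.log(denom)
    return o, lse

Whether the actual merge happens in sglang or in a flashinfer utility is outside this hunk; the sketch only shows why the LSE must travel with each partial output.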
@@ -96,23 +210,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
+        self.kv_indices = torch.empty(
+            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
+
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
-        fmha_backend = "auto"
+        self.fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            self.fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD", backend=fmha_backend
+            self.workspace_buffer, "NHD", backend=self.fmha_backend
         )
 
         if not self.skip_prefill:
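
Note (illustration, not part of the diff): the shared kv_indices buffer allocated above is sized in pages rather than tokens. Each request needs at most roughly ceil(max_context_len / page_size) page indices, and the diff's expression, max_bs * (max_context_len + page_size - 1) // page_size, is at least max_bs times that. A standalone sketch with made-up numbers:

max_bs = 160
max_context_len = 8192
page_size = 64

pages_per_req = (max_context_len + page_size - 1) // page_size          # ceil -> 128
buffer_len = max_bs * (max_context_len + page_size - 1) // page_size    # as in the diff
assert buffer_len >= max_bs * pages_per_req
print(pages_per_req, buffer_len)                                        # 128 20637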
@@ -136,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                 model_runner, self
             )
+            if self.enable_chunk_kv:
+                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)
 
         self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
             model_runner, self
@@ -147,6 +273,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -204,16 +331,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        if kv_indices_buf is None:
-            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
-                dtype=torch.int32,
-                device="cuda",
-            )
-        else:
-            cuda_graph_kv_indices = kv_indices_buf
-
-        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_kv_indices = (
+            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
+        )
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -239,6 +359,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
+
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -249,7 +370,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
-
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -320,11 +440,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
+            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                kv_len_arr_cpu, dim=0
+                num_pages_per_req, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -333,7 +455,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
-
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -369,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+
     def forward_extend(
         self,
         q: torch.Tensor,
@@ -380,6 +505,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if (
+            forward_batch.attn_attend_prefix_cache is not None
+            and forward_batch.mha_return_lse
+        ):  # MHA Chunk
+            assert self.enable_chunk_kv
+            assert q_rope is None
+            assert k_rope is None
+            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
+            return o1, s1
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
@@ -400,7 +534,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -410,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
@@ -421,6 +554,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
+            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
+
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
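
Note (illustration, not part of the diff): the k_buf.view(-1, self.page_size, k_buf.shape[-1]) calls added in this and the next hunk reinterpret the flat per-slot key buffer as [num_pages, page_size, head_dim] so the paged kernels can address whole pages. A toy sketch of the same reshape with made-up sizes; the assumed [slots, 1, head_dim] layout (one latent KV "head" for MLA) is an illustration, not the exact pool layout:

import torch

page_size = 4
head_dim = 576              # e.g. kv_lora_rank + qk_rope_head_dim for an MLA model
num_slots = 32              # KV pool slots, a multiple of page_size

k_buf = torch.randn(num_slots, 1, head_dim)           # assumed flat buffer layout
k_paged = k_buf.view(-1, page_size, k_buf.shape[-1])
print(k_paged.shape)                                   # torch.Size([8, 4, 576])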
@@ -482,17 +617,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
+        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
-        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-            k_buffer[:, :, : layer.v_head_dim],
-            k_buffer[:, :, layer.v_head_dim :],
+            k_buf[:, :, : layer.v_head_dim],
+            k_buf[:, :, layer.v_head_dim :],
             out=o,
         )
 
@@ -511,9 +646,10 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+        self.page_size = model_runner.page_size
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -557,13 +693,17 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+                self.kv_indices[: kv_indptr[-1]]
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
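
Note (illustration, not part of the diff): with page_size > 1, the indptr built above must count pages per request rather than tokens, and kv_indices becomes a slice of the preallocated buffer instead of a fresh allocation. A standalone example of the arithmetic with made-up lengths:

import torch

page_size = 4
seq_lens = torch.tensor([5, 8, 1], dtype=torch.int32)        # tokens per request

num_pages_per_req = (seq_lens + page_size - 1) // page_size  # -> [2, 2, 1]
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)       # -> [0, 2, 4, 5]
num_pages_total = int(kv_indptr[-1])                         # slice length of kv_indices
print(kv_indptr.tolist(), num_pages_total)                   # [0, 2, 4, 5] 5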
@@ -572,39 +712,40 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr=q_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
         else:
             wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
+                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices=kv_indices,
+                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
            )
 
 
@@ -626,12 +767,14 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -645,7 +788,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
-
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -679,13 +821,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum,
-                dtype=torch.int32,
-                device=req_pool_indices.device,
-            )
+            kv_indices = self.kv_indices[: kv_indptr[-1]]
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -694,6 +835,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -711,7 +853,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
-
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -722,23 +863,30 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                 head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
+                causal=True,
             )
         else:
             # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            if spec_info is not None:
+                assert (
+                    self.page_size == 1
+                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+            else:
+                kv_lens = paged_kernel_lens.to(torch.int32)
             wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                True,
-                sm_scale,
-                self.q_data_type,
-                self.data_type,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=True,
+                sm_scale=sm_scale,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
             )
 
 
@@ -833,6 +981,7 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
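
Note (illustration, not part of the diff): the page_size == 1 assertion added in the paged-prefill hunk above and the draft-backend changes here both concern speculative decoding, where kv_indptr comes directly from spec_info. A quick standalone check of why that special case is consistent:

import torch

page_size = 1   # the only value the assertion above allows with speculative decoding
seq_lens = torch.tensor([5, 8, 1], dtype=torch.int32)

num_pages_per_req = (seq_lens + page_size - 1) // page_size
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)

# With one token per page, page counts equal token counts, so indptr
# differences recover the per-request KV lengths directly.
kv_lens = kv_indptr[1:] - kv_indptr[:-1]
assert torch.equal(kv_lens, seq_lens)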
@@ -868,6 +1017,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -884,6 +1034,7 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,