sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -24,9 +24,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -61,6 +59,115 @@ class PrefillMetadata:
 global_workspace_buffer = None
 
 
+class FlashInferMhaChunkKVRunner:
+    def __init__(
+        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+    ):
+        # Parse Constants
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.data_type = model_runner.dtype
+        self.q_data_type = model_runner.dtype
+
+        # Buffers and wrappers
+        self.qo_indptr = attn_backend.qo_indptr
+        self.workspace_buffer = attn_backend.workspace_buffer
+        self.fmha_backend = attn_backend.fmha_backend
+
+        self.chunk_ragged_wrappers = []
+        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
+
+    def update_prefix_chunks(self, num_prefix_chunks: int):
+        while num_prefix_chunks > len(self.chunk_ragged_wrappers):
+            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
+                self.workspace_buffer, "NHD", backend=self.fmha_backend
+            )
+            self.chunk_ragged_wrappers.append(ragged_wrapper)
+
+    def update_wrapper(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.num_prefix_chunks is not None
+        num_prefix_chunks = forward_batch.num_prefix_chunks
+        self.update_prefix_chunks(num_prefix_chunks)
+
+        prefix_lens = forward_batch.extend_prefix_lens
+        seq_lens = forward_batch.seq_lens
+
+        bs = len(seq_lens)
+        qo_indptr = self.qo_indptr
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+        qo_indptr = qo_indptr[: bs + 1]
+
+        for chunk_idx in range(forward_batch.num_prefix_chunks):
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=False,
+            )
+        # ragged prefill
+        self.ragged_wrapper.begin_forward(
+            qo_indptr=qo_indptr,
+            kv_indptr=qo_indptr,
+            num_qo_heads=self.num_local_heads,
+            num_kv_heads=self.num_local_heads,
+            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+            head_dim_vo=self.v_head_dim,
+            q_data_type=self.q_data_type,
+            causal=True,
+        )
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
+        logits_soft_cap = layer.logit_cap
+        if forward_batch.attn_attend_prefix_cache:
+            chunk_idx = forward_batch.prefix_chunk_idx
+            assert chunk_idx >= 0
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            o1, s1 = wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=False,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+        else:
+            o1, s1 = self.ragged_wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+        return o1, s1
+
+
 class FlashInferMLAAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
@@ -72,11 +179,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
-
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.enable_chunk_kv = (
+            not skip_prefill
+            and global_server_args_dict["disaggregation_mode"] != "decode"
+            and not global_server_args_dict["disable_chunked_prefix_cache"]
+            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+        )
+        self.page_size = model_runner.page_size
 
         # Allocate buffers
         global global_workspace_buffer
@@ -97,23 +210,33 @@
         else:
             self.kv_indptr = kv_indptr_buf
 
+        self.kv_indices = torch.empty(
+            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
+
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
-        fmha_backend = "auto"
+        self.fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            self.fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD", backend=fmha_backend
+            self.workspace_buffer, "NHD", backend=self.fmha_backend
         )
 
         if not self.skip_prefill:
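
The kv_indices buffer added in the hunk above is allocated once, sized for the worst case in which every request of a full batch sits at the maximum context length, instead of allocating a fresh index tensor on each forward pass. A quick check of the sizing arithmetic, with illustrative numbers that are not taken from this diff:

    # Worst case: each request needs ceil(max_context_len / page_size) page indices.
    max_bs, max_context_len, page_size = 160, 8192, 64
    pages_per_req = (max_context_len + page_size - 1) // page_size        # 128
    buffer_len = max_bs * (max_context_len + page_size - 1) // page_size  # 20637
    assert buffer_len >= max_bs * pages_per_req                           # 20637 >= 20480
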
@@ -137,6 +260,8 @@
             self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                 model_runner, self
             )
+            if self.enable_chunk_kv:
+                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)
 
         self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
             model_runner, self
@@ -148,6 +273,7 @@
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -205,16 +331,9 @@
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        if kv_indices_buf is None:
-            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
-                dtype=torch.int32,
-                device="cuda",
-            )
-        else:
-            cuda_graph_kv_indices = kv_indices_buf
-
-        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_kv_indices = (
+            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
+        )
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -240,6 +359,7 @@
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
+
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -250,7 +370,6 @@
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
-
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -321,11 +440,13 @@
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
+            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                kv_len_arr_cpu, dim=0
+                num_pages_per_req, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -334,7 +455,6 @@
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
-
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -370,6 +490,10 @@
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+
     def forward_extend(
         self,
         q: torch.Tensor,
@@ -381,6 +505,15 @@
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if (
+            forward_batch.attn_attend_prefix_cache is not None
+            and forward_batch.mha_return_lse
+        ):  # MHA Chunk
+            assert self.enable_chunk_kv
+            assert q_rope is None
+            assert k_rope is None
+            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
+            return o1, s1
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
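
When forward_extend takes the MHA-chunk path above, it returns the partial attention output o1 together with its log-sum-exp s1 rather than a finished result, so partial outputs computed over disjoint KV ranges (each prefix chunk plus the new-token ragged part) can be merged afterwards. A minimal sketch of that standard log-sum-exp merge; the function and names are illustrative and not part of this diff:

    import torch

    def merge_attention_states(o_a, lse_a, o_b, lse_b):
        # o_*: [tokens, heads, head_dim]; lse_*: [tokens, heads]
        # Rescale each partial output by its softmax normalizer and combine.
        lse_max = torch.maximum(lse_a, lse_b)
        w_a = torch.exp(lse_a - lse_max)
        w_b = torch.exp(lse_b - lse_max)
        denom = w_a + w_b
        merged = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / denom.unsqueeze(-1)
        merged_lse = lse_max + torch.log(denom)
        return merged, merged_lse
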
@@ -401,7 +534,6 @@
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -411,8 +543,8 @@
                 k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
@@ -422,6 +554,8 @@
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
+            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
+
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -483,17 +617,17 @@
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
+        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
-        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-            k_buffer[:, :, : layer.v_head_dim],
-            k_buffer[:, :, layer.v_head_dim :],
+            k_buf[:, :, : layer.v_head_dim],
+            k_buf[:, :, layer.v_head_dim :],
             out=o,
         )
 
@@ -512,9 +646,10 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+        self.page_size = model_runner.page_size
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -558,13 +693,17 @@
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+                self.kv_indices[: kv_indptr[-1]]
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
            )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
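
The hunk above switches kv_indptr from per-token counts to per-page counts so the indices line up with the paged KV layout: each request contributes ceil(seq_len / page_size) entries, and the cumulative sum gives its offset into the flat kv_indices array. A standalone sketch of the same computation, simplified relative to the diff (the real code writes into preallocated buffers):

    import torch

    def build_paged_kv_indptr(seq_lens: torch.Tensor, page_size: int) -> torch.Tensor:
        # Ceil-divide each sequence length into pages, then prefix-sum the counts.
        num_pages_per_req = (seq_lens + page_size - 1) // page_size
        kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
        kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)
        return kv_indptr

    # e.g. seq_lens=[5, 17] with page_size=16 -> pages per request [1, 2] -> indptr [0, 1, 3]
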
@@ -573,39 +712,40 @@
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr=q_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
         else:
             wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
+                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices=kv_indices,
+                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
 
 
@@ -627,12 +767,14 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -646,7 +788,6 @@
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
-
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -680,13 +821,12 @@
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum,
-                dtype=torch.int32,
-                device=req_pool_indices.device,
-            )
+            kv_indices = self.kv_indices[: kv_indptr[-1]]
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -695,6 +835,7 @@
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -712,7 +853,6 @@
                     self.req_to_token,
                 )
             )
-
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -723,23 +863,30 @@
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                 head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
+                causal=True,
             )
         else:
             # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            if spec_info is not None:
+                assert (
+                    self.page_size == 1
+                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+            else:
+                kv_lens = paged_kernel_lens.to(torch.int32)
             wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                True,
-                sm_scale,
-                self.q_data_type,
-                self.data_type,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=True,
+                sm_scale=sm_scale,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
             )
 
 
@@ -834,6 +981,7 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -869,6 +1017,7 @@
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -885,6 +1034,7 @@
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,