sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -119,6 +119,7 @@ class Indexer(CustomOp):
  prefix: str = "",
  quant_config: Optional[QuantizationConfig] = None,
  alt_stream: Optional[torch.cuda.Stream] = None,
+ fuse_wk_and_weights_proj: bool = False,
  ):
  super().__init__()
  self.hidden_size = hidden_size
@@ -129,6 +130,7 @@ class Indexer(CustomOp):
  self.q_lora_rank = q_lora_rank
  self.layer_id = layer_id
  self.alt_stream = alt_stream
+ self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
  if is_cuda():
  self.sm_count = deep_gemm.get_num_sms()
  self.half_device_sm_count = align(self.sm_count // 2, 8)
@@ -140,21 +142,29 @@ class Indexer(CustomOp):
  quant_config=quant_config,
  prefix=add_prefix("wq_b", prefix),
  )
- self.wk = ReplicatedLinear(
- self.hidden_size,
- self.head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("wk", prefix),
- )
+ if self.fuse_wk_and_weights_proj:
+ self.fused_wk_and_weights_proj = ReplicatedLinear(
+ self.hidden_size,
+ self.head_dim + self.n_heads,
+ bias=False,
+ prefix=add_prefix("fused_wk_and_weights_proj", prefix),
+ )
+ else:
+ self.wk = ReplicatedLinear(
+ self.hidden_size,
+ self.head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("wk", prefix),
+ )
+ # NOTE: weight_proj is not quantized
+ self.weights_proj = ReplicatedLinear(
+ self.hidden_size,
+ self.n_heads,
+ bias=False,
+ prefix=add_prefix("weights_proj", prefix),
+ )
  self.k_norm = V32LayerNorm(self.head_dim)
- # NOTE: weight_proj is not quantized
- self.weights_proj = ReplicatedLinear(
- self.hidden_size,
- self.n_heads,
- bias=False,
- prefix=add_prefix("weights_proj", prefix),
- )
  self.rotary_emb = get_rope_wrapper(
  rope_head_dim,
  rotary_dim=rope_head_dim,
@@ -169,8 +179,7 @@ class Indexer(CustomOp):
  self.softmax_scale = self.head_dim**-0.5

  @torch.compile(dynamic=True)
- def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
- weights, _ = self.weights_proj(x)
+ def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
  weights = weights * self.n_heads**-0.5
  weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
  return weights
@@ -182,7 +191,7 @@ class Indexer(CustomOp):
  positions: torch.Tensor,
  enable_dual_stream: bool,
  ):
-
+ weights = None
  if enable_dual_stream:
  current_stream = torch.cuda.current_stream()
  self.alt_stream.wait_stream(current_stream)
@@ -199,7 +208,12 @@ class Indexer(CustomOp):
  )
  with torch.cuda.stream(self.alt_stream):
  # TODO we should also put DeepGEMM half SM here?
- key, _ = self.wk(x)
+ if self.fuse_wk_and_weights_proj:
+ key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+ [self.head_dim, self.n_heads], dim=-1
+ )
+ else:
+ key, _ = self.wk(x)
  key = self.k_norm(key)

  k_rope, _ = torch.split(
@@ -217,7 +231,12 @@ class Indexer(CustomOp):
  query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
  )

- key, _ = self.wk(x)
+ if self.fuse_wk_and_weights_proj:
+ key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+ [self.head_dim, self.n_heads], dim=-1
+ )
+ else:
+ key, _ = self.wk(x)
  key = self.k_norm(key)
  k_rope, _ = torch.split(
  key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -240,7 +259,7 @@ class Indexer(CustomOp):
  query = rotate_activation(query)
  key = rotate_activation(key)

- return query, key
+ return query, key, weights

  def _get_topk_paged(
  self,
@@ -490,7 +509,9 @@ class Indexer(CustomOp):
  if metadata is None:
  return None

- query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+ query, key, weights = self._get_q_k_bf16(
+ q_lora, x, positions, enable_dual_stream
+ )

  if enable_dual_stream:
  current_stream = torch.cuda.current_stream()
@@ -517,7 +538,9 @@ class Indexer(CustomOp):
  index_k_scale=k_scale,
  )

- weights = self._get_logits_head_gate(x, q_scale)
+ if not self.fuse_wk_and_weights_proj:
+ weights, _ = self.weights_proj(x)
+ weights = self._get_logits_head_gate(weights, q_scale)

  if is_cuda():
  assert forward_batch.seq_lens_cpu is not None
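Aside: the core of the Indexer change above is replacing the separate wk and weights_proj projections with one fused matmul whose output is split back into the key and the per-head gate weights. A minimal standalone sketch of that pattern, with toy dimensions and a plain torch.nn.Linear standing in for sglang's ReplicatedLinear:

import torch

hidden_size, head_dim, n_heads, num_tokens = 64, 16, 4, 8
# one fused projection instead of separate wk (head_dim) and weights_proj (n_heads)
fused_wk_and_weights_proj = torch.nn.Linear(hidden_size, head_dim + n_heads, bias=False)

x = torch.randn(num_tokens, hidden_size)
key, weights = fused_wk_and_weights_proj(x).split([head_dim, n_heads], dim=-1)
assert key.shape == (num_tokens, head_dim)
assert weights.shape == (num_tokens, n_heads)

One design note visible in the diff: the fused projection is constructed without quant_config, consistent with the existing rule that weights_proj stays unquantized.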
sglang/srt/layers/attention/nsa/quant_k_cache.py
@@ -206,6 +206,8 @@ def _quantize_k_cache_fast_kernel(


  if __name__ == "__main__":
+ import dequant_k_cache
+
  for num_blocks, block_size in [
  (1, 1),
  (10, 64),
@@ -217,21 +219,9 @@ if __name__ == "__main__":
  dtype=torch.bfloat16,
  device="cuda",
  )
- # temp debug
- # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)

  ref_quant = _quantize_k_cache_slow(input_k_cache)
  actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
- # print(f"{input_k_cache=}")
- # print(f"{ref_quant=}")
- # print(f"{actual_quant=}")
- # print(f"{ref_quant == actual_quant=}")
- # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
- # print(f"{ref_quant.view(torch.bfloat16)=}")
- # print(f"{actual_quant.view(torch.bfloat16)=}")
- # assert torch.all(ref_quant == actual_quant)
-
- import dequant_k_cache

  ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
  ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
@@ -252,4 +242,46 @@ if __name__ == "__main__":
  ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
  )

+ # test dequant_k_cache_paged
+ page_table_1 = torch.arange(
+ num_blocks * block_size, dtype=torch.int32, device="cuda"
+ )
+ actual_dequant_paged = dequant_k_cache.dequantize_k_cache_paged(
+ actual_quant, page_table_1
+ ).reshape(actual_actual_dequant.shape)
+ print(f"{torch.mean(actual_actual_dequant - actual_dequant_paged)=}")
+ torch.testing.assert_close(
+ ref_ref_dequant, actual_dequant_paged, atol=0.2, rtol=0.2
+ )
+
  print("Passed")
+ print("Do benchmark...")
+
+ for num_blocks, block_size in [
+ (1, 64),
+ (64, 64),
+ (128, 64),
+ (256, 64),
+ (512, 64),
+ (1024, 64),
+ (2048, 64),
+ ]:
+ dim_nope_and_rope = 512 + 64
+
+ input_k_cache = torch.randn(
+ (num_blocks, block_size, 1, dim_nope_and_rope),
+ dtype=torch.bfloat16,
+ device="cuda",
+ )
+
+ actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
+
+ page_table_1 = torch.arange(
+ num_blocks * block_size, dtype=torch.int32, device="cuda"
+ )
+
+ def run_ans():
+ return dequant_k_cache.dequantize_k_cache_paged(actual_quant, page_table_1)
+
+ ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
+ print(f"seq_kv: {num_blocks * block_size}, time: {ans_time * 1e6: 4.0f} us")
sglang/srt/layers/attention/nsa/transform_index.py
@@ -103,7 +103,7 @@ def transform_index_page_table_decode_ref(
  result = torch.empty_like(topk_indices, dtype=torch.int32)
  assert result.shape == topk_indices.shape
  torch.gather(
- page_table,
+ page_table.to(result.dtype),
  dim=1,
  index=topk_indices.clamp(min=0),
  out=result,
sglang/srt/layers/attention/nsa_backend.py
@@ -1,12 +1,14 @@
  from __future__ import annotations

  from dataclasses import dataclass
+ from enum import IntEnum, auto
  from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias

  import torch

  from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged
  from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
  from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
  from sglang.srt.layers.attention.nsa.transform_index import (
@@ -98,11 +100,27 @@ class NSAMetadata:
  nsa_max_seqlen_q: Literal[1] = 1 # always 1 for decode, variable for extend

  flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
+ # The sum of sequence lengths for key, prefill only
+ seq_lens_sum: Optional[int] = None
+ # The flattened 1D page table with shape (seq_lens_sum,), prefill only
+ # this table is always with page_size = 1
+ page_table_1_flattened: Optional[torch.Tensor] = None
+ # The offset of topk indices in ragged kv, prefill only
+ # shape: (seq_lens_sum,)
+ topk_indices_offset: Optional[torch.Tensor] = None
+
+
+ class TopkTransformMethod(IntEnum):
+ # Transform topk indices to indices to the page table (page_size = 1)
+ PAGED = auto()
+ # Transform topk indices to indices to ragged kv (non-paged)
+ RAGGED = auto()


  @dataclass(frozen=True)
  class NSAIndexerMetadata(BaseIndexerMetadata):
  attn_metadata: NSAMetadata
+ topk_transform_method: TopkTransformMethod

  def get_seqlens_int32(self) -> torch.Tensor:
  return self.attn_metadata.cache_seqlens_int32
@@ -118,23 +136,36 @@ class NSAIndexerMetadata(BaseIndexerMetadata):
  logits: torch.Tensor,
  topk: int,
  ) -> torch.Tensor:
- from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+ from sgl_kernel import (
+ fast_topk_transform_fused,
+ fast_topk_transform_ragged_fused,
+ fast_topk_v2,
+ )

  if not NSA_FUSE_TOPK:
  return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
-
- # NOTE(dark): if fused, we return a transformed page table directly
- return fast_topk_transform_fused(
- score=logits,
- lengths=self.get_seqlens_expanded(),
- page_table_size_1=self.attn_metadata.page_table_1,
- cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
- topk=topk,
- )
+ elif self.topk_transform_method == TopkTransformMethod.PAGED:
+ # NOTE(dark): if fused, we return a transformed page table directly
+ return fast_topk_transform_fused(
+ score=logits,
+ lengths=self.get_seqlens_expanded(),
+ page_table_size_1=self.attn_metadata.page_table_1,
+ cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
+ topk=topk,
+ )
+ elif self.topk_transform_method == TopkTransformMethod.RAGGED:
+ return fast_topk_transform_ragged_fused(
+ score=logits,
+ lengths=self.get_seqlens_expanded(),
+ topk_indices_offset=self.attn_metadata.topk_indices_offset,
+ topk=topk,
+ )
+ else:
+ assert False, f"Unsupported {self.topk_transform_method = }"


  def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
- assert seqlens.dtype == torch.int32 and seqlens.is_cuda
+ assert seqlens.dtype == torch.int32
  return torch.nn.functional.pad(
  torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
  )
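Aside: the RAGGED path above maps each query's per-sequence topk indices into one flattened (page_size = 1) KV layout by adding that request's start offset, while -1 padding entries stay untouched. A small illustrative sketch with toy values, not sglang APIs:

import torch

seq_lens_k = torch.tensor([3, 5], dtype=torch.int32)       # kv tokens per request
extend_seq_lens = torch.tensor([2, 2], dtype=torch.int32)  # query tokens per request
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seq_lens_k, 0), (1, 0))  # [0, 3, 8]

# one offset per query token: where its request's kv segment starts in the ragged layout
topk_indices_offset = torch.repeat_interleave(cu_seqlens_k[:-1], extend_seq_lens)  # [0, 0, 3, 3]

topk_indices = torch.tensor([[0, 2], [1, -1], [4, 0], [-1, -1]])  # -1 marks padding
mask = topk_indices != -1
ragged_indices = torch.where(mask, topk_indices + topk_indices_offset.unsqueeze(1), topk_indices)
# ragged_indices == [[0, 2], [1, -1], [7, 3], [-1, -1]]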
@@ -181,6 +212,7 @@ class NativeSparseAttnBackend(AttentionBackend):
  global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
  NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
  NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
+ self.enable_auto_select_prefill_impl = NSA_PREFILL_IMPL == "flashmla_auto"

  self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)

@@ -231,10 +263,16 @@ class NativeSparseAttnBackend(AttentionBackend):
  cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
  assert forward_batch.seq_lens_cpu is not None
  max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
+ # [b, max_seqlen_k]
  page_table = forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, :max_seqlen_k
  ]

+ page_table_1_flattened = None
+ topk_indices_offset = None
+ self.set_nsa_prefill_impl(forward_batch)
+ topk_transform_method = self.get_topk_transform_method()
+
  if forward_batch.forward_mode.is_decode_or_idle():
  extend_seq_lens_cpu = [1] * batch_size
  max_seqlen_q = 1
@@ -295,6 +333,7 @@ class NativeSparseAttnBackend(AttentionBackend):
  else:
  max_seqlen_q = max_seqlen_k
  cu_seqlens_q = cu_seqlens_k
+
  seqlens_expanded = torch.cat(
  [
  torch.arange(
@@ -310,6 +349,24 @@ class NativeSparseAttnBackend(AttentionBackend):
  )
  ]
  )
+
+ if topk_transform_method == TopkTransformMethod.RAGGED:
+ page_table_1_flattened = torch.cat(
+ [
+ page_table[i, :kv_len]
+ for i, kv_len in enumerate(
+ forward_batch.seq_lens_cpu.tolist(),
+ )
+ ]
+ )
+ assert (
+ page_table_1_flattened.shape[0] == forward_batch.seq_lens_sum
+ ), f"{page_table_1_flattened.shape[0] = } must be the same as {forward_batch.seq_lens_sum = }"
+
+ topk_indices_offset = torch.repeat_interleave(
+ cu_seqlens_k[:-1],
+ forward_batch.extend_seq_lens,
+ )
  else:
  assert False, f"Unsupported {forward_batch.forward_mode = }"

@@ -328,7 +385,9 @@ class NativeSparseAttnBackend(AttentionBackend):
  max_seq_len_k=max_seqlen_k,
  cu_seqlens_q=cu_seqlens_q,
  cu_seqlens_k=cu_seqlens_k,
+ seq_lens_sum=forward_batch.seq_lens_sum,
  page_table_1=page_table,
+ page_table_1_flattened=page_table_1_flattened,
  flashmla_metadata=(
  self._compute_flashmla_metadata(
  cache_seqlens=nsa_cache_seqlens_int32,
@@ -344,6 +403,7 @@ class NativeSparseAttnBackend(AttentionBackend):
  nsa_extend_seq_lens_list=extend_seq_lens_cpu,
  real_page_table=self._transform_table_1_to_real(page_table),
  nsa_max_seqlen_q=1,
+ topk_indices_offset=topk_indices_offset,
  )

  self.forward_metadata = metadata
@@ -396,6 +456,8 @@ class NativeSparseAttnBackend(AttentionBackend):
  forward_mode: ForwardMode,
  spec_info: Optional[SpecInput],
  ):
+ self.set_nsa_prefill_impl(forward_batch=None)
+
  """Initialize forward metadata for capturing CUDA graph."""
  if forward_mode.is_decode_or_idle():
  # Normal Decode
@@ -586,6 +648,8 @@ class NativeSparseAttnBackend(AttentionBackend):
  """Initialize forward metadata for replaying CUDA graph."""
  assert seq_lens_cpu is not None

+ self.set_nsa_prefill_impl(forward_batch=None)
+
  seq_lens = seq_lens[:bs]
  seq_lens_cpu = seq_lens_cpu[:bs]
  req_pool_indices = req_pool_indices[:bs]
@@ -780,17 +844,31 @@ class NativeSparseAttnBackend(AttentionBackend):
  q_rope = q_all[:, :, layer.v_head_dim :]

  # NOTE(dark): here, we use page size = 1
-
+ topk_transform_method = self.get_topk_transform_method()
  if NSA_FUSE_TOPK:
  page_table_1 = topk_indices
  else:
- assert metadata.nsa_extend_seq_lens_list is not None
- page_table_1 = transform_index_page_table_prefill(
- page_table=metadata.page_table_1,
- topk_indices=topk_indices,
- extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
- page_size=1,
- )
+ if topk_transform_method == TopkTransformMethod.RAGGED:
+ topk_indices_offset = metadata.topk_indices_offset
+ assert topk_indices_offset is not None
+ mask = topk_indices != -1
+ topk_indices_offset = (
+ topk_indices_offset.unsqueeze(1)
+ if topk_indices_offset.ndim == 1
+ else topk_indices_offset
+ )
+ topk_indices = torch.where(
+ mask, topk_indices + topk_indices_offset, topk_indices
+ )
+ elif topk_transform_method == TopkTransformMethod.PAGED:
+ assert metadata.nsa_extend_seq_lens_list is not None
+ page_table_1 = transform_index_page_table_prefill(
+ page_table=metadata.page_table_1,
+ topk_indices=topk_indices,
+ extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+ page_size=1,
+ )
+
  if NSA_PREFILL_IMPL == "tilelang":
  if q_rope is not None:
  q_all = torch.cat([q_nope, q_rope], dim=-1)
@@ -804,6 +882,22 @@ class NativeSparseAttnBackend(AttentionBackend):
  elif NSA_PREFILL_IMPL == "flashmla_sparse":
  if q_rope is not None:
  q_all = torch.cat([q_nope, q_rope], dim=-1)
+
+ # NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here,
+ # because the flashmla_sparse kernel doesn't support fp8 compute
+ if topk_transform_method == TopkTransformMethod.RAGGED:
+ if any(forward_batch.extend_prefix_lens_cpu):
+ page_table_1_flattened = (
+ self.forward_metadata.page_table_1_flattened
+ )
+ assert page_table_1_flattened is not None
+ kv_cache = dequantize_k_cache_paged(
+ kv_cache, page_table_1_flattened
+ )
+ else:
+ kv_cache = torch.cat([k, k_rope], dim=-1)
+ page_table_1 = topk_indices
+
  return self._forward_flashmla_sparse(
  q_all=q_all,
  kv_cache=kv_cache,
@@ -1004,7 +1098,7 @@ class NativeSparseAttnBackend(AttentionBackend):
  page_table_1: torch.Tensor,
  sm_scale: float,
  ) -> torch.Tensor:
- from flash_mla import flash_mla_sparse_fwd
+ from sgl_kernel.flash_mla import flash_mla_sparse_fwd

  o, _, _ = flash_mla_sparse_fwd(
  q=q_all,
@@ -1025,7 +1119,7 @@ class NativeSparseAttnBackend(AttentionBackend):
  metadata: NSAMetadata,
  page_table_1,
  ) -> torch.Tensor:
- from flash_mla import flash_mla_with_kvcache
+ from sgl_kernel.flash_mla import flash_mla_with_kvcache

  cache_seqlens = metadata.nsa_cache_seqlens_int32

@@ -1121,13 +1215,53 @@ class NativeSparseAttnBackend(AttentionBackend):
  """Get the fill value for sequence length in CUDA graph."""
  return 1

+ def set_nsa_prefill_impl(self, forward_batch: Optional[ForwardBatch] = None) -> str:
+ from sglang.srt.utils import is_blackwell
+
+ global NSA_PREFILL_IMPL
+ if self.enable_auto_select_prefill_impl:
+ if self.nsa_kv_cache_store_fp8:
+ if (
+ is_blackwell()
+ and forward_batch is not None
+ and forward_batch.forward_mode == ForwardMode.EXTEND
+ ):
+ total_kv_tokens = forward_batch.seq_lens_sum
+ total_q_tokens = forward_batch.extend_num_tokens
+ # Heuristic based on benchmarking flashmla_kv vs flashmla_sparse + dequantize_k_cache_paged
+ if total_kv_tokens < total_q_tokens * 512:
+ NSA_PREFILL_IMPL = "flashmla_sparse"
+ return
+ NSA_PREFILL_IMPL = "flashmla_kv"
+ else:
+ # bf16 kv cache
+ NSA_PREFILL_IMPL = "flashmla_sparse"
+
+ def get_topk_transform_method(self) -> TopkTransformMethod:
+ """
+ NSA_FUSE_TOPK controls whether to fuse the topk transform into the topk kernel.
+ This method is used to select the topk transform method which can be fused or unfused.
+ """
+ if (
+ # disable for MTP
+ self.nsa_kv_cache_store_fp8
+ and NSA_PREFILL_IMPL == "flashmla_sparse"
+ ):
+ topk_transform_method = TopkTransformMethod.RAGGED
+ else:
+ topk_transform_method = TopkTransformMethod.PAGED
+ return topk_transform_method
+
  def get_indexer_metadata(
  self, layer_id: int, forward_batch: ForwardBatch
  ) -> NSAIndexerMetadata:
- return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
+ return NSAIndexerMetadata(
+ attn_metadata=self.forward_metadata,
+ topk_transform_method=self.get_topk_transform_method(),
+ )

  def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
- from flash_mla import get_mla_metadata
+ from sgl_kernel.flash_mla import get_mla_metadata

  flashmla_metadata, num_splits = get_mla_metadata(
  cache_seqlens=cache_seqlens,
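Aside: when nsa_prefill_backend is set to "flashmla_auto", the new set_nsa_prefill_impl picks a prefill kernel per batch. A simplified, illustrative restatement of that decision as a standalone function (the real method also checks the forward mode and mutates the module-level NSA_PREFILL_IMPL):

def pick_nsa_prefill_impl(
    kv_cache_is_fp8: bool,
    on_blackwell_extend: bool,   # is_blackwell() and an EXTEND forward batch
    total_kv_tokens: int,        # forward_batch.seq_lens_sum
    total_q_tokens: int,         # forward_batch.extend_num_tokens
) -> str:
    if not kv_cache_is_fp8:
        return "flashmla_sparse"  # bf16 kv cache
    if on_blackwell_extend and total_kv_tokens < total_q_tokens * 512:
        # kv short relative to queries: dequantize + flashmla_sparse wins
        return "flashmla_sparse"
    return "flashmla_kv"

assert pick_nsa_prefill_impl(True, True, 10_000, 100) == "flashmla_sparse"
assert pick_nsa_prefill_impl(True, True, 100_000, 100) == "flashmla_kv"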
sglang/srt/layers/attention/triton_backend.py
@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
  self.num_kv_head = model_runner.model_config.get_num_kv_heads(
  get_attention_tp_size()
  )
- if model_runner.hybrid_gdn_config is not None:
+ if (
+ model_runner.hybrid_gdn_config is not None
+ or model_runner.kimi_linear_config is not None
+ ):
  # For hybrid linear models, layer_id = 0 may not be full attention
  self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
  else:
sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]

- if (
- any(forward_batch.extend_prefix_lens_cpu)
- or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
- ):
+ if any(
+ forward_batch.extend_prefix_lens_cpu
+ ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
  extend_seq_lens = forward_batch.extend_seq_lens
  metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
  metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -529,6 +528,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
  layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )

+ if self.data_type == torch.float8_e4m3fn:
+ q = q.to(torch.float8_e4m3fn)
  q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
  k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
  # shape conversion:
@@ -567,6 +568,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
  window_left=layer.sliding_window_size,
  # TODO: add attention_sink operation or nvfp4 scale factor if needed
  sinks=attention_sink,
+ out_dtype=self.q_data_type, # model_runner.dtype
  )

  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -586,6 +588,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
  forward_batch.token_to_kv_pool.set_kv_buffer(
  layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
+
+ if self.data_type == torch.float8_e4m3fn:
+ q = q.to(torch.float8_e4m3fn)
  q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
  # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
  k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -625,6 +630,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
  window_left=layer.sliding_window_size,
  # TODO: add attention_sink operation or nvfp4 scale factor if needed
  sinks=attention_sink,
+ out_dtype=self.q_data_type, # model_runner.dtype
  )

  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -423,14 +423,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
  PAGED_SIZE=self.page_size,
  )

- # Record the true maximum sequence length for this capture batch so that
- # the kernel launch path (which requires an int not a tensor) can reuse
- # it safely during both capture and replay.
- max_seq_len_val = int(seq_lens.max().item())
-
  metadata = TRTLLMMLADecodeMetadata(
  block_kv_indices,
- max_seq_len_val,
+ self.max_context_len,
  )
  if forward_mode.is_draft_extend(include_v2=True):
  num_tokens_per_bs = num_tokens // bs
@@ -509,13 +504,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
  PAGED_SIZE=self.page_size,
  )

- # Update stored max_seq_len so subsequent kernel calls use the correct value
- # Prefer CPU tensor to avoid GPU synchronization when available.
- if seq_lens_cpu is not None:
- metadata.max_seq_len = int(seq_lens_cpu.max().item())
- else:
- metadata.max_seq_len = int(seq_lens.max().item())
-
  def get_cuda_graph_seq_len_fill_value(self) -> int:
  """Get the fill value for sequence lengths in CUDA graph."""
  return 1
@@ -956,8 +944,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
  metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
  )
  else:
- seq_lens = forward_batch.seq_lens.to(torch.int32)
- max_seq_len = metadata.max_seq_len_k
+ # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
+ # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
+ # This will ensure queries align with kvs correctly when calling
+ # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
+ seq_lens = (
+ forward_batch.seq_lens
+ - metadata.seq_lens_q
+ + metadata.max_seq_len_q
+ ).to(torch.int32)
+ max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
  # Check if we're in CUDA graph mode (buffers are pre-allocated)
  if self.padded_q_buffer is not None:
  # Use pre-allocated buffer for CUDA graph compatibility
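Aside: the seq_lens adjustment in the last hunk is plain arithmetic; each request's kv length is extended by however many padding tokens its query gains when padded to the batch-wide max draft length. A toy illustration with made-up values:

import torch

seq_lens = torch.tensor([10, 12])      # prev context + verified tokens per request
seq_lens_q = torch.tensor([2, 3])      # real draft-extend query tokens per request
max_seq_len_q = int(seq_lens_q.max())  # queries are padded to this length (3)

padded_seq_lens = (seq_lens - seq_lens_q + max_seq_len_q).to(torch.int32)
# request 0 gains one padded query token, so its reported kv length grows from 10 to 11
assert padded_seq_lens.tolist() == [11, 12]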