sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,9 @@ import triton
  import triton.language as tl
  from einops import rearrange

- from sglang.srt.utils import device_context
+ from sglang.srt.utils import device_context, is_npu
+
+ _is_npu = is_npu()


  def rms_norm_ref(
@@ -182,6 +184,10 @@ def _layer_norm_fwd(
  return out, mean, rstd


+ if _is_npu:
+ from sgl_kernel_npu.fla.layernorm_gated import layer_norm_fwd_npu as _layer_norm_fwd
+
+
  def rms_norm_gated(
  *,
  x,
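The two hunks above use an import-time dispatch pattern: probe the platform once (`_is_npu = is_npu()`), then rebind the module-level kernel name so every caller transparently picks up the NPU implementation. A minimal sketch of that pattern with stand-in names (`_is_npu_platform`, `_fast_norm`, `_fast_norm_npu` are hypothetical, not from sglang):

```python
# Sketch of the import-time override pattern; stand-in functions only.
# The real code rebinds _layer_norm_fwd to layer_norm_fwd_npu from sgl_kernel_npu.


def _is_npu_platform() -> bool:
    # Stand-in for sglang.srt.utils.is_npu(); always False in this sketch.
    return False


def _fast_norm(x: float) -> float:
    # Default (Triton/CUDA) implementation.
    return x * 2.0


if _is_npu_platform():
    def _fast_norm(x: float) -> float:  # noqa: F811
        # On NPU, the same name would point at the vendor kernel instead.
        return x * 2.0


def rms_norm_gated_like(x: float) -> float:
    # Callers go through the module-level name, so the override is global.
    return _fast_norm(x)


print(rms_norm_gated_like(1.5))  # same call site on both platforms
```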
@@ -584,7 +584,9 @@ class FlashAttentionBackend(AttentionBackend):
  metadata, metadata_expand
  )

- elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
+ elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(
+ include_draft_extend_v2=True
+ ):
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
  metadata.cu_seqlens_k = torch.nn.functional.pad(
@@ -594,10 +596,9 @@ class FlashAttentionBackend(AttentionBackend):
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]

- if (
- any(forward_batch.extend_prefix_lens_cpu)
- or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
- ):
+ if any(
+ forward_batch.extend_prefix_lens_cpu
+ ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
  extend_seq_lens = forward_batch.extend_seq_lens
  metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
  metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -826,7 +827,7 @@ class FlashAttentionBackend(AttentionBackend):
  if (
  forward_batch.attn_attend_prefix_cache is not None
  and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
+ and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
  ):
  # Do multi-head attention with chunked prefix cache
  if forward_batch.attn_attend_prefix_cache:
@@ -855,14 +856,24 @@ class FlashAttentionBackend(AttentionBackend):
  )
  else:
  # MHA for extend part of sequence without attending prefix kv cache
+ cu_seqlens_k = (
+ metadata.cu_seqlens_q
+ if not forward_batch.mha_one_shot
+ else metadata.cu_seqlens_k
+ )
+ max_seqlen_k = (
+ metadata.max_seq_len_q
+ if not forward_batch.mha_one_shot
+ else metadata.max_seq_len_k
+ )
  output = flash_attn_varlen_func(
  q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
  k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
  v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
  cu_seqlens_q=metadata.cu_seqlens_q,
- cu_seqlens_k=metadata.cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
  max_seqlen_q=metadata.max_seq_len_q,
- max_seqlen_k=metadata.max_seq_len_q,
+ max_seqlen_k=max_seqlen_k,
  softmax_scale=layer.scaling,
  causal=True,
  return_softmax_lse=forward_batch.mha_return_lse,
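The `mha_one_shot` branch above changes only which cumulative-length tensor feeds the key side of `flash_attn_varlen_func`. A toy illustration of how `cu_seqlens_q` and `cu_seqlens_k` differ when prefixes are attended in one shot (the lengths are invented for the example, not taken from this diff):

```python
# Toy numbers only: two requests with prefix lengths [4, 2] and new (extend) tokens [3, 5].
import torch

extend_seq_lens = torch.tensor([3, 5])         # query tokens per request
total_seq_lens = torch.tensor([4 + 3, 2 + 5])  # prefix + new tokens per request

cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(total_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)

print(cu_seqlens_q.tolist())  # [0, 3, 8]  -> keys limited to the new tokens
print(cu_seqlens_k.tolist())  # [0, 7, 14] -> keys cover prefix + new tokens (one-shot MHA)
```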
@@ -230,7 +230,16 @@ class FlashInferAttnBackend(AttentionBackend):

  fmha_backend = "auto"
  if is_sm100_supported():
- fmha_backend = "cutlass"
+ # Disable CUTLASS backend when piecewise cuda graph is enabled
+ # due to TMA descriptor initialization issues on B200
+ if model_runner.server_args.enable_piecewise_cuda_graph:
+ logger.warning(
+ "CUTLASS backend is disabled when piecewise cuda graph is enabled "
+ "due to TMA descriptor initialization issues on B200. "
+ "Using auto backend instead for stability."
+ )
+ else:
+ fmha_backend = "cutlass"
  self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
  self.workspace_buffer, "NHD", backend=fmha_backend
  )
@@ -82,6 +82,7 @@ class FlashInferMhaChunkKVRunner:

  # Buffers and wrappers
  self.qo_indptr = attn_backend.qo_indptr
+ self.kv_indptr = attn_backend.kv_indptr
  self.workspace_buffer = attn_backend.workspace_buffer
  self.fmha_backend = attn_backend.fmha_backend

@@ -132,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
  )
  # ragged prefill
  if not disable_flashinfer_ragged:
+ kv_indptr = (
+ qo_indptr
+ if not forward_batch.mha_one_shot
+ else self.kv_indptr[: bs + 1]
+ )
  self.ragged_wrapper.begin_forward(
  qo_indptr=qo_indptr,
- kv_indptr=qo_indptr,
+ kv_indptr=kv_indptr,
  num_qo_heads=self.num_local_heads,
  num_kv_heads=self.num_local_heads,
  head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
@@ -156,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
  chunk_idx = forward_batch.prefix_chunk_idx
  assert chunk_idx >= 0
  wrapper = self.chunk_ragged_wrappers[chunk_idx]
- o1, s1 = wrapper.forward_return_lse(
+ o = wrapper.forward_return_lse(
  q.view(-1, layer.tp_q_head_num, layer.head_dim),
  k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
  v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -165,7 +171,12 @@ class FlashInferMhaChunkKVRunner:
  logits_soft_cap=logits_soft_cap,
  )
  else:
- o1, s1 = self.ragged_wrapper.forward_return_lse(
+ forward = (
+ self.ragged_wrapper.forward_return_lse
+ if forward_batch.mha_return_lse
+ else self.ragged_wrapper.forward
+ )
+ o = forward(
  q.view(-1, layer.tp_q_head_num, layer.head_dim),
  k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
  v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -173,8 +184,7 @@ class FlashInferMhaChunkKVRunner:
  sm_scale=layer.scaling,
  logits_soft_cap=logits_soft_cap,
  )
-
- return o1, s1
+ return o


  class FlashInferMLAAttnBackend(AttentionBackend):
@@ -232,9 +242,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  else:
  self.q_indptr_decode = q_indptr_decode_buf

- self.fmha_backend = "auto"
  if is_sm100_supported():
  self.fmha_backend = "cutlass"
+ else:
+ self.fmha_backend = "auto"
+
  self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
  self.workspace_buffer, "NHD", backend=self.fmha_backend
  )
@@ -512,15 +524,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  q_rope: Optional[torch.Tensor] = None,
  k_rope: Optional[torch.Tensor] = None,
  ):
- if (
- forward_batch.attn_attend_prefix_cache is not None
- and forward_batch.mha_return_lse
+ if forward_batch.attn_attend_prefix_cache is not None and any(
+ forward_batch.extend_prefix_lens_cpu
  ): # MHA Chunk
  assert self.enable_chunk_kv
  assert q_rope is None
  assert k_rope is None
- o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
- return o1, s1
+ return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)

  cache_loc = forward_batch.out_cache_loc
  logits_soft_cap = layer.logit_cap
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

  import torch
  import triton
- from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+ from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata

  from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
  from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
@@ -1,6 +1,7 @@
  from typing import Optional, Union

  import torch
+ from einops import rearrange

  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
  from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
  fused_sigmoid_gating_delta_rule_update,
  )
+ from sglang.srt.layers.attention.fla.kda import (
+ chunk_kda,
+ fused_kda_gate,
+ fused_recurrent_kda,
+ )
  from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
  PAD_SLOT_ID,
  causal_conv1d_fn,
@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend):
  return 1 # Mamba attn does not use seq lens to index kv cache


+ class KimiLinearAttnBackend(MambaAttnBackendBase):
+ """Attention backend using Mamba kernel."""
+
+ def forward_decode(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache: bool = True,
+ **kwargs,
+ ):
+ q_proj_states = kwargs["q_proj_states"]
+ k_proj_states = kwargs["k_proj_states"]
+ v_proj_states = kwargs["v_proj_states"]
+ q_conv_weights = kwargs["q_conv_weights"]
+ k_conv_weights = kwargs["k_conv_weights"]
+ v_conv_weights = kwargs["v_conv_weights"]
+
+ q_conv_bias = kwargs["q_conv_bias"]
+ k_conv_bias = kwargs["k_conv_bias"]
+ v_conv_bias = kwargs["v_conv_bias"]
+
+ A_log = kwargs["A_log"]
+ dt_bias = kwargs["dt_bias"]
+ b_proj = kwargs["b_proj"]
+ f_a_proj = kwargs["f_a_proj"]
+ f_b_proj = kwargs["f_b_proj"]
+ hidden_states = kwargs["hidden_states"]
+ head_dim = kwargs["head_dim"]
+ layer_id = kwargs["layer_id"]
+
+ layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+ q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
+ ssm_states = layer_cache.temporal
+ query_start_loc = self.forward_metadata.query_start_loc
+ cache_indices = self.forward_metadata.mamba_cache_indices
+
+ q_conv_state = q_conv_state.transpose(-1, -2)
+ k_conv_state = k_conv_state.transpose(-1, -2)
+ v_conv_state = v_conv_state.transpose(-1, -2)
+
+ q = causal_conv1d_update(
+ q_proj_states,
+ q_conv_state,
+ q_conv_weights,
+ q_conv_bias,
+ activation="silu",
+ conv_state_indices=cache_indices,
+ )
+ k = causal_conv1d_update(
+ k_proj_states,
+ k_conv_state,
+ k_conv_weights,
+ k_conv_bias,
+ activation="silu",
+ conv_state_indices=cache_indices,
+ )
+ v = causal_conv1d_update(
+ v_proj_states,
+ v_conv_state,
+ v_conv_weights,
+ v_conv_bias,
+ activation="silu",
+ conv_state_indices=cache_indices,
+ )
+
+ q, k, v = map(
+ lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
+ )
+
+ beta = b_proj(hidden_states)[0].float().sigmoid()
+
+ g = f_b_proj(f_a_proj(hidden_states)[0])[0]
+ g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
+
+ beta = beta.unsqueeze(0)
+ g = g.unsqueeze(0)
+
+ initial_state = ssm_states[cache_indices].contiguous()
+ (
+ core_attn_out,
+ last_recurrent_state,
+ ) = fused_recurrent_kda(
+ q=q,
+ k=k,
+ v=v,
+ g=g,
+ beta=beta,
+ initial_state=initial_state,
+ use_qk_l2norm_in_kernel=True,
+ cu_seqlens=query_start_loc,
+ )
+ ssm_states[cache_indices] = last_recurrent_state
+ return core_attn_out
+
+ def forward_extend(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache: bool = True,
+ **kwargs,
+ ):
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+ causal_conv1d_fn,
+ )
+
+ q_proj_states = kwargs["q_proj_states"]
+ k_proj_states = kwargs["k_proj_states"]
+ v_proj_states = kwargs["v_proj_states"]
+ q_conv_weights = kwargs["q_conv_weights"]
+ k_conv_weights = kwargs["k_conv_weights"]
+ v_conv_weights = kwargs["v_conv_weights"]
+
+ q_conv_bias = kwargs["q_conv_bias"]
+ k_conv_bias = kwargs["k_conv_bias"]
+ v_conv_bias = kwargs["v_conv_bias"]
+
+ A_log = kwargs["A_log"]
+ dt_bias = kwargs["dt_bias"]
+ b_proj = kwargs["b_proj"]
+ f_a_proj = kwargs["f_a_proj"]
+ f_b_proj = kwargs["f_b_proj"]
+ hidden_states = kwargs["hidden_states"]
+ head_dim = kwargs["head_dim"]
+ layer_id = kwargs["layer_id"]
+
+ query_start_loc = self.forward_metadata.query_start_loc
+ cache_indices = self.forward_metadata.mamba_cache_indices
+
+ mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+ conv_state_q, conv_state_k, conv_state_v = mamba_cache_params.conv
+ # deal with strides
+ conv_state_q = conv_state_q.transpose(-1, -2)
+ conv_state_k = conv_state_k.transpose(-1, -2)
+ conv_state_v = conv_state_v.transpose(-1, -2)
+
+ ssm_states = mamba_cache_params.temporal
+
+ has_initial_state = forward_batch.extend_prefix_lens > 0
+
+ q_proj_states = q_proj_states.transpose(0, 1)
+ k_proj_states = k_proj_states.transpose(0, 1)
+ v_proj_states = v_proj_states.transpose(0, 1)
+
+ q = causal_conv1d_fn(
+ q_proj_states,
+ q_conv_weights,
+ q_conv_bias,
+ activation="silu",
+ conv_states=conv_state_q,
+ has_initial_state=has_initial_state,
+ cache_indices=cache_indices,
+ query_start_loc=query_start_loc,
+ seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+ ).transpose(0, 1)
+
+ k = causal_conv1d_fn(
+ k_proj_states,
+ k_conv_weights,
+ k_conv_bias,
+ activation="silu",
+ conv_states=conv_state_k,
+ has_initial_state=has_initial_state,
+ cache_indices=cache_indices,
+ query_start_loc=query_start_loc,
+ seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+ ).transpose(0, 1)
+
+ v = causal_conv1d_fn(
+ v_proj_states,
+ v_conv_weights,
+ v_conv_bias,
+ activation="silu",
+ conv_states=conv_state_v,
+ has_initial_state=has_initial_state,
+ cache_indices=cache_indices,
+ query_start_loc=query_start_loc,
+ seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+ ).transpose(0, 1)
+
+ q, k, v = map(
+ lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
+ )
+
+ beta = b_proj(hidden_states)[0].float().sigmoid()
+
+ g = f_b_proj(f_a_proj(hidden_states)[0])[0]
+ g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
+
+ beta = beta.unsqueeze(0)
+ g = g.unsqueeze(0)
+
+ initial_state = ssm_states[cache_indices].contiguous()
+ (
+ core_attn_out,
+ last_recurrent_state,
+ ) = chunk_kda(
+ q=q,
+ k=k,
+ v=v,
+ g=g,
+ beta=beta,
+ initial_state=initial_state,
+ output_final_state=True,
+ use_qk_l2norm_in_kernel=True,
+ cu_seqlens=query_start_loc,
+ )
+ ssm_states[cache_indices] = last_recurrent_state
+
+ return core_attn_out
+
+
  class GDNAttnBackend(MambaAttnBackendBase):
  """Attention backend using Mamba kernel."""

@@ -13,16 +13,6 @@ from sglang.srt.distributed import (
  get_tensor_model_parallel_world_size,
  )
  from sglang.srt.distributed.utils import divide
- from sglang.srt.layers.attention.mamba.causal_conv1d import (
- causal_conv1d_fn,
- causal_conv1d_update,
- )
- from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
- causal_conv1d_fn as causal_conv1d_fn_triton,
- )
- from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
- causal_conv1d_update as causal_conv1d_update_triton,
- )
  from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
  from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
  from sglang.srt.layers.attention.mamba.ops import (
@@ -40,7 +30,26 @@ from sglang.srt.model_loader.weight_utils import (
  composed_weight_loader,
  sharded_weight_loader,
  )
- from sglang.srt.utils import set_weight_attrs
+ from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+
+ if is_cuda():
+ from sglang.srt.layers.attention.mamba.causal_conv1d import (
+ causal_conv1d_fn,
+ causal_conv1d_update,
+ )
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+ causal_conv1d_fn as causal_conv1d_fn_triton,
+ )
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+ causal_conv1d_update as causal_conv1d_update_triton,
+ )
+ elif is_npu():
+ from sgl_kernel_npu.mamba.causal_conv1d import (
+ causal_conv1d_fn_npu as causal_conv1d_fn,
+ )
+ from sgl_kernel_npu.mamba.causal_conv1d import (
+ causal_conv1d_update_npu as causal_conv1d_update,
+ )

  LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]

@@ -22,6 +22,10 @@ def _dequantize_k_cache_slow(
  De-quantize the k-cache
  """
  assert dv % tile_size == 0
+ original_ndim = quant_k_cache.ndim
+ if original_ndim == 3:
+ # set block_size = 1
+ quant_k_cache = quant_k_cache.unsqueeze(1)
  num_tiles = dv // tile_size
  num_blocks, block_size, h_k, _ = quant_k_cache.shape
  assert h_k == 1
@@ -45,8 +49,10 @@ def _dequantize_k_cache_slow(
  cur_nope * cur_scales
  )

- result = result.view(num_blocks, block_size, 1, d)
- return result
+ if original_ndim == 3:
+ return result.view(num_blocks, 1, -1)
+ else:
+ return result.view(num_blocks, block_size, 1, -1)


  def _dequantize_k_cache_fast_wrapped(
@@ -54,7 +60,10 @@ def _dequantize_k_cache_fast_wrapped(
  dv: int = 512,
  tile_size: int = 128,
  ) -> torch.Tensor:
- # TODO the final API may be 2D instead of 4D, thus we convert them here
+ original_ndim = quant_k_cache.ndim
+ if original_ndim == 3:
+ # set block_size = 1
+ quant_k_cache = quant_k_cache.unsqueeze(1)
  num_blocks, block_size, _, dim_quant = quant_k_cache.shape
  assert dv == 512
  assert dim_quant == 656
@@ -63,7 +72,10 @@ def _dequantize_k_cache_fast_wrapped(

  output = _dequantize_k_cache_fast(quant_k_cache)

- return output.view(num_blocks, block_size, 1, -1)
+ if original_ndim == 3:
+ return output.view(num_blocks, 1, -1)
+ else:
+ return output.view(num_blocks, block_size, 1, -1)


  def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
@@ -85,7 +97,6 @@ def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
  assert num_blocks_per_token == 5

  assert dim_nope % group_size == 0
- NUM_NOPE_BLOCKS = dim_nope // group_size

  input_nope_q = quant_k_cache[:, :dim_nope]
  input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
@@ -102,7 +113,7 @@ def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
  input_nope_q.stride(0),
  input_nope_s.stride(0),
  input_rope.stride(0),
- NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
+ NUM_NOPE_BLOCKS=num_tiles,
  GROUP_SIZE=group_size,
  DIM_NOPE=dim_nope,
  DIM_ROPE=dim_rope,
@@ -159,5 +170,126 @@ def _dequantize_k_cache_fast_kernel(
  tl.store(dst_ptr, data, mask=mask)


+ def dequantize_k_cache_paged(
+ quant_k_cache: torch.Tensor,
+ page_table_1_flattened: torch.Tensor,
+ group_size: int = 128,
+ ) -> torch.Tensor:
+ """
+ De-quantize the k-cache with paged layout
+ Args:
+ quant_k_cache: [total_num_tokens, 1, dim_quant] or [num_blocks, block_size, 1, dim_quant], the quantized k-cache in paged layout
+ page_table_1_flattened: [num_tokens], the flattened page_table_1 with the page indices in each requests concatenated together
+ Returns:
+ output: [num_tokens, 1, dim_nope + dim_rope], the de-quantized k-cache
+ """
+ dim_quant = quant_k_cache.shape[-1]
+ assert (
+ dim_quant == 656
+ ), f"dim_quant: {dim_quant} != 656 detected in dequantize_k_cache_paged"
+ quant_k_cache = quant_k_cache.view((-1, dim_quant))
+
+ total_num_tokens, _ = quant_k_cache.shape
+ num_tokens = page_table_1_flattened.shape[0]
+ assert num_tokens <= total_num_tokens
+
+ assert quant_k_cache.dtype == torch.float8_e4m3fn
+ dim_nope = 512
+ dim_rope = 64
+ num_tiles = dim_nope // group_size # 512 // 128 = 4
+
+ output = torch.empty(
+ (num_tokens, 1, dim_nope + dim_rope),
+ dtype=torch.bfloat16,
+ device=quant_k_cache.device,
+ )
+
+ # cdiv(512 + 64, 128) = 5
+ num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
+ assert num_blocks_per_token == 5
+
+ assert dim_nope % group_size == 0
+
+ input_nope_q = quant_k_cache[:, :dim_nope]
+ # [:, 512:512+4*4] = [:, 512:528]
+ input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
+ torch.float32
+ )
+ # [:, 528:]
+ input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
+
+ _dequantize_k_cache_paged_kernel[(num_tokens, num_blocks_per_token)](
+ output,
+ input_nope_q,
+ input_nope_s,
+ input_rope,
+ page_table_1_flattened,
+ output.stride(0),
+ input_nope_q.stride(0),
+ input_nope_s.stride(0),
+ input_rope.stride(0),
+ NUM_NOPE_BLOCKS=num_tiles,
+ GROUP_SIZE=group_size,
+ DIM_NOPE=dim_nope,
+ DIM_ROPE=dim_rope,
+ )
+
+ return output
+
+
+ @triton.jit
+ def _dequantize_k_cache_paged_kernel(
+ output_ptr,
+ input_nope_q_ptr,
+ input_nope_s_ptr,
+ input_rope_ptr,
+ page_table_1_ptr,
+ output_stride_0: int,
+ input_nope_q_stride_0: int,
+ input_nope_s_stride_0: int,
+ input_rope_stride_0: int,
+ NUM_NOPE_BLOCKS: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ DIM_NOPE: tl.constexpr,
+ DIM_ROPE: tl.constexpr,
+ ):
+ token_id = tl.program_id(0)
+ token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32)
+ raw_block_id = tl.program_id(1)
+
+ if raw_block_id < NUM_NOPE_BLOCKS:
+ # a. dequant nope
+ effective_block_id = raw_block_id
+
+ offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+ mask = offs_q < DIM_NOPE
+ ptr_q = input_nope_q_ptr + token_id_paged * input_nope_q_stride_0 + offs_q
+ ptr_s = (
+ input_nope_s_ptr
+ + token_id_paged * input_nope_s_stride_0
+ + effective_block_id
+ )
+
+ y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
+ y_s = tl.load(ptr_s)
+
+ y = (y_q * y_s).to(output_ptr.dtype.element_ty)
+
+ dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
+ tl.store(dst_ptr, y, mask=mask)
+ else:
+ # b. copy rope
+ effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
+
+ offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+ mask = offs < DIM_ROPE
+
+ src_ptr = input_rope_ptr + token_id_paged * input_rope_stride_0 + offs
+ dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
+
+ data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
+ tl.store(dst_ptr, data, mask=mask)
+
+
  if __name__ == "__main__":
  raise Exception("UT is in quant_k_cache.py")
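For orientation, a hedged usage sketch of the new `dequantize_k_cache_paged` entry point. The tensors are dummies; the shapes and layout constants (512 fp8 nope bytes + 16 scale bytes + 128 rope bytes = 656) follow the docstring and asserts above, the import path follows this wheel's file layout, and a CUDA device with Triton is assumed:

```python
# Illustrative only: dummy cache contents, real module path per this diff's file list.
import torch

from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged

total_num_tokens, dim_quant = 1024, 656
quant_k_cache = torch.zeros(
    total_num_tokens, 1, dim_quant, dtype=torch.float8_e4m3fn, device="cuda"
)

# Flattened page table: which cache slots the current batch's tokens occupy,
# concatenated request by request (indices here are made up).
page_table_1_flattened = torch.tensor(
    [3, 4, 5, 100, 101], dtype=torch.int32, device="cuda"
)

k = dequantize_k_cache_paged(quant_k_cache, page_table_1_flattened)
print(k.shape, k.dtype)  # torch.Size([5, 1, 576]) torch.bfloat16
```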