sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/fla/layernorm_gated.py

@@ -12,7 +12,9 @@ import triton
 import triton.language as tl
 from einops import rearrange
 
-from sglang.srt.utils import device_context
+from sglang.srt.utils import device_context, is_npu
+
+_is_npu = is_npu()
 
 
 def rms_norm_ref(
@@ -182,6 +184,10 @@ def _layer_norm_fwd(
     return out, mean, rstd
 
 
+if _is_npu:
+    from sgl_kernel_npu.fla.layernorm_gated import layer_norm_fwd_npu as _layer_norm_fwd
+
+
 def rms_norm_gated(
     *,
     x,

sglang/srt/layers/attention/flashattention_backend.py

@@ -584,7 +584,9 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata, metadata_expand
             )
 
-        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(
+            include_draft_extend_v2=True
+        ):
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
@@ -594,10 +596,9 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            if (
-                any(forward_batch.extend_prefix_lens_cpu)
-                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-            ):
+            if any(
+                forward_batch.extend_prefix_lens_cpu
+            ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -826,7 +827,7 @@ class FlashAttentionBackend(AttentionBackend):
         if (
             forward_batch.attn_attend_prefix_cache is not None
             and not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:

sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -242,9 +242,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
-        self.fmha_backend = "auto"
         if is_sm100_supported():
             self.fmha_backend = "cutlass"
+        else:
+            self.fmha_backend = "auto"
+
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD", backend=self.fmha_backend
         )

sglang/srt/layers/attention/flashmla_backend.py

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 
 import torch
 import triton
-from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton

sglang/srt/layers/attention/hybrid_linear_attn_backend.py

@@ -1,6 +1,7 @@
 from typing import Optional, Union
 
 import torch
+from einops import rearrange
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
 from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
     fused_sigmoid_gating_delta_rule_update,
 )
+from sglang.srt.layers.attention.fla.kda import (
+    chunk_kda,
+    fused_kda_gate,
+    fused_recurrent_kda,
+)
 from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
     PAD_SLOT_ID,
     causal_conv1d_fn,
@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend):
         return 1  # Mamba attn does not use seq lens to index kv cache
 
 
+class KimiLinearAttnBackend(MambaAttnBackendBase):
+    """Attention backend using Mamba kernel."""
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        q_proj_states = kwargs["q_proj_states"]
+        k_proj_states = kwargs["k_proj_states"]
+        v_proj_states = kwargs["v_proj_states"]
+        q_conv_weights = kwargs["q_conv_weights"]
+        k_conv_weights = kwargs["k_conv_weights"]
+        v_conv_weights = kwargs["v_conv_weights"]
+
+        q_conv_bias = kwargs["q_conv_bias"]
+        k_conv_bias = kwargs["k_conv_bias"]
+        v_conv_bias = kwargs["v_conv_bias"]
+
+        A_log = kwargs["A_log"]
+        dt_bias = kwargs["dt_bias"]
+        b_proj = kwargs["b_proj"]
+        f_a_proj = kwargs["f_a_proj"]
+        f_b_proj = kwargs["f_b_proj"]
+        hidden_states = kwargs["hidden_states"]
+        head_dim = kwargs["head_dim"]
+        layer_id = kwargs["layer_id"]
+
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
+        ssm_states = layer_cache.temporal
+        query_start_loc = self.forward_metadata.query_start_loc
+        cache_indices = self.forward_metadata.mamba_cache_indices
+
+        q_conv_state = q_conv_state.transpose(-1, -2)
+        k_conv_state = k_conv_state.transpose(-1, -2)
+        v_conv_state = v_conv_state.transpose(-1, -2)
+
+        q = causal_conv1d_update(
+            q_proj_states,
+            q_conv_state,
+            q_conv_weights,
+            q_conv_bias,
+            activation="silu",
+            conv_state_indices=cache_indices,
+        )
+        k = causal_conv1d_update(
+            k_proj_states,
+            k_conv_state,
+            k_conv_weights,
+            k_conv_bias,
+            activation="silu",
+            conv_state_indices=cache_indices,
+        )
+        v = causal_conv1d_update(
+            v_proj_states,
+            v_conv_state,
+            v_conv_weights,
+            v_conv_bias,
+            activation="silu",
+            conv_state_indices=cache_indices,
+        )
+
+        q, k, v = map(
+            lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
+        )
+
+        beta = b_proj(hidden_states)[0].float().sigmoid()
+
+        g = f_b_proj(f_a_proj(hidden_states)[0])[0]
+        g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
+
+        beta = beta.unsqueeze(0)
+        g = g.unsqueeze(0)
+
+        initial_state = ssm_states[cache_indices].contiguous()
+        (
+            core_attn_out,
+            last_recurrent_state,
+        ) = fused_recurrent_kda(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            initial_state=initial_state,
+            use_qk_l2norm_in_kernel=True,
+            cu_seqlens=query_start_loc,
+        )
+        ssm_states[cache_indices] = last_recurrent_state
+        return core_attn_out
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+            causal_conv1d_fn,
+        )
+
+        q_proj_states = kwargs["q_proj_states"]
+        k_proj_states = kwargs["k_proj_states"]
+        v_proj_states = kwargs["v_proj_states"]
+        q_conv_weights = kwargs["q_conv_weights"]
+        k_conv_weights = kwargs["k_conv_weights"]
+        v_conv_weights = kwargs["v_conv_weights"]
+
+        q_conv_bias = kwargs["q_conv_bias"]
+        k_conv_bias = kwargs["k_conv_bias"]
+        v_conv_bias = kwargs["v_conv_bias"]
+
+        A_log = kwargs["A_log"]
+        dt_bias = kwargs["dt_bias"]
+        b_proj = kwargs["b_proj"]
+        f_a_proj = kwargs["f_a_proj"]
+        f_b_proj = kwargs["f_b_proj"]
+        hidden_states = kwargs["hidden_states"]
+        head_dim = kwargs["head_dim"]
+        layer_id = kwargs["layer_id"]
+
+        query_start_loc = self.forward_metadata.query_start_loc
+        cache_indices = self.forward_metadata.mamba_cache_indices
+
+        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_state_q, conv_state_k, conv_state_v = mamba_cache_params.conv
+        # deal with strides
+        conv_state_q = conv_state_q.transpose(-1, -2)
+        conv_state_k = conv_state_k.transpose(-1, -2)
+        conv_state_v = conv_state_v.transpose(-1, -2)
+
+        ssm_states = mamba_cache_params.temporal
+
+        has_initial_state = forward_batch.extend_prefix_lens > 0
+
+        q_proj_states = q_proj_states.transpose(0, 1)
+        k_proj_states = k_proj_states.transpose(0, 1)
+        v_proj_states = v_proj_states.transpose(0, 1)
+
+        q = causal_conv1d_fn(
+            q_proj_states,
+            q_conv_weights,
+            q_conv_bias,
+            activation="silu",
+            conv_states=conv_state_q,
+            has_initial_state=has_initial_state,
+            cache_indices=cache_indices,
+            query_start_loc=query_start_loc,
+            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+        ).transpose(0, 1)
+
+        k = causal_conv1d_fn(
+            k_proj_states,
+            k_conv_weights,
+            k_conv_bias,
+            activation="silu",
+            conv_states=conv_state_k,
+            has_initial_state=has_initial_state,
+            cache_indices=cache_indices,
+            query_start_loc=query_start_loc,
+            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+        ).transpose(0, 1)
+
+        v = causal_conv1d_fn(
+            v_proj_states,
+            v_conv_weights,
+            v_conv_bias,
+            activation="silu",
+            conv_states=conv_state_v,
+            has_initial_state=has_initial_state,
+            cache_indices=cache_indices,
+            query_start_loc=query_start_loc,
+            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+        ).transpose(0, 1)
+
+        q, k, v = map(
+            lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
+        )
+
+        beta = b_proj(hidden_states)[0].float().sigmoid()
+
+        g = f_b_proj(f_a_proj(hidden_states)[0])[0]
+        g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
+
+        beta = beta.unsqueeze(0)
+        g = g.unsqueeze(0)
+
+        initial_state = ssm_states[cache_indices].contiguous()
+        (
+            core_attn_out,
+            last_recurrent_state,
+        ) = chunk_kda(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            initial_state=initial_state,
+            output_final_state=True,
+            use_qk_l2norm_in_kernel=True,
+            cu_seqlens=query_start_loc,
+        )
+        ssm_states[cache_indices] = last_recurrent_state
+
+        return core_attn_out
+
+
 class GDNAttnBackend(MambaAttnBackendBase):
     """Attention backend using Mamba kernel."""
 
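Both new forward paths reshape the flat per-token projections into the single-batch, varlen layout expected by the KDA kernels (fused_recurrent_kda for decode, chunk_kda for extend), with sequence boundaries carried in cu_seqlens=query_start_loc. A minimal, self-contained sketch of that layout change; the sizes below are illustrative only, not the model's real dimensions:

import torch
from einops import rearrange

# Illustrative sizes; the real head_dim arrives through kwargs["head_dim"].
num_tokens, n_heads, head_dim = 6, 4, 32
q = torch.randn(num_tokens, n_heads * head_dim)

# Same rearrange pattern as in forward_decode / forward_extend above:
# flat [num_tokens, n_heads * head_dim] -> [1, num_tokens, n_heads, head_dim].
q4d = rearrange(q, "n (h d) -> 1 n h d", d=head_dim)
assert q4d.shape == (1, num_tokens, n_heads, head_dim)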

sglang/srt/layers/attention/mamba/mamba.py

@@ -13,16 +13,6 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.distributed.utils import divide
-from sglang.srt.layers.attention.mamba.causal_conv1d import (
-    causal_conv1d_fn,
-    causal_conv1d_update,
-)
-from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
-    causal_conv1d_fn as causal_conv1d_fn_triton,
-)
-from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
-    causal_conv1d_update as causal_conv1d_update_triton,
-)
 from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
 from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
 from sglang.srt.layers.attention.mamba.ops import (
@@ -40,7 +30,26 @@ from sglang.srt.model_loader.weight_utils import (
     composed_weight_loader,
     sharded_weight_loader,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+
+if is_cuda():
+    from sglang.srt.layers.attention.mamba.causal_conv1d import (
+        causal_conv1d_fn,
+        causal_conv1d_update,
+    )
+    from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+        causal_conv1d_fn as causal_conv1d_fn_triton,
+    )
+    from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+        causal_conv1d_update as causal_conv1d_update_triton,
+    )
+elif is_npu():
+    from sgl_kernel_npu.mamba.causal_conv1d import (
+        causal_conv1d_fn_npu as causal_conv1d_fn,
+    )
+    from sgl_kernel_npu.mamba.causal_conv1d import (
+        causal_conv1d_update_npu as causal_conv1d_update,
+    )
 
 LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
 

sglang/srt/layers/attention/nsa/dequant_k_cache.py

@@ -22,6 +22,10 @@ def _dequantize_k_cache_slow(
     De-quantize the k-cache
     """
     assert dv % tile_size == 0
+    original_ndim = quant_k_cache.ndim
+    if original_ndim == 3:
+        # set block_size = 1
+        quant_k_cache = quant_k_cache.unsqueeze(1)
     num_tiles = dv // tile_size
     num_blocks, block_size, h_k, _ = quant_k_cache.shape
     assert h_k == 1
@@ -45,8 +49,10 @@
                 cur_nope * cur_scales
             )
 
-    result = result.view(num_blocks, block_size, 1, d)
-    return result
+    if original_ndim == 3:
+        return result.view(num_blocks, 1, -1)
+    else:
+        return result.view(num_blocks, block_size, 1, -1)
 
 
 def _dequantize_k_cache_fast_wrapped(
@@ -54,7 +60,10 @@ def _dequantize_k_cache_fast_wrapped(
     dv: int = 512,
     tile_size: int = 128,
 ) -> torch.Tensor:
-    # TODO the final API may be 2D instead of 4D, thus we convert them here
+    original_ndim = quant_k_cache.ndim
+    if original_ndim == 3:
+        # set block_size = 1
+        quant_k_cache = quant_k_cache.unsqueeze(1)
     num_blocks, block_size, _, dim_quant = quant_k_cache.shape
     assert dv == 512
     assert dim_quant == 656
@@ -63,7 +72,10 @@
 
     output = _dequantize_k_cache_fast(quant_k_cache)
 
-    return output.view(num_blocks, block_size, 1, -1)
+    if original_ndim == 3:
+        return output.view(num_blocks, 1, -1)
+    else:
+        return output.view(num_blocks, block_size, 1, -1)
 
 
 def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
@@ -85,7 +97,6 @@ def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
     assert num_blocks_per_token == 5
 
     assert dim_nope % group_size == 0
-    NUM_NOPE_BLOCKS = dim_nope // group_size
 
     input_nope_q = quant_k_cache[:, :dim_nope]
     input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
@@ -102,7 +113,7 @@
         input_nope_q.stride(0),
         input_nope_s.stride(0),
         input_rope.stride(0),
-        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
+        NUM_NOPE_BLOCKS=num_tiles,
         GROUP_SIZE=group_size,
         DIM_NOPE=dim_nope,
         DIM_ROPE=dim_rope,
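With the wrapper changes above, the dequant helpers accept either the flat 3-D cache layout or the paged 4-D layout: a 3-D input is unsqueezed to block_size = 1 on entry and viewed back to 3-D on return. A torch-only sketch of that shape round trip (shapes illustrative, uint8 standing in for the fp8 cache):

import torch

flat = torch.zeros(16, 1, 656, dtype=torch.uint8)  # stand-in for the fp8 k-cache
as_4d = flat.unsqueeze(1)                          # [16, 1, 1, 656], i.e. block_size = 1
restored = as_4d.view(16, 1, -1)                   # viewed back to [16, 1, 656] on return
assert restored.shape == flat.shape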
@@ -159,5 +170,126 @@ def _dequantize_k_cache_fast_kernel(
         tl.store(dst_ptr, data, mask=mask)
 
 
+def dequantize_k_cache_paged(
+    quant_k_cache: torch.Tensor,
+    page_table_1_flattened: torch.Tensor,
+    group_size: int = 128,
+) -> torch.Tensor:
+    """
+    De-quantize the k-cache with paged layout
+    Args:
+        quant_k_cache: [total_num_tokens, 1, dim_quant] or [num_blocks, block_size, 1, dim_quant], the quantized k-cache in paged layout
+        page_table_1_flattened: [num_tokens], the flattened page_table_1 with the page indices in each requests concatenated together
+    Returns:
+        output: [num_tokens, 1, dim_nope + dim_rope], the de-quantized k-cache
+    """
+    dim_quant = quant_k_cache.shape[-1]
+    assert (
+        dim_quant == 656
+    ), f"dim_quant: {dim_quant} != 656 detected in dequantize_k_cache_paged"
+    quant_k_cache = quant_k_cache.view((-1, dim_quant))
+
+    total_num_tokens, _ = quant_k_cache.shape
+    num_tokens = page_table_1_flattened.shape[0]
+    assert num_tokens <= total_num_tokens
+
+    assert quant_k_cache.dtype == torch.float8_e4m3fn
+    dim_nope = 512
+    dim_rope = 64
+    num_tiles = dim_nope // group_size  # 512 // 128 = 4
+
+    output = torch.empty(
+        (num_tokens, 1, dim_nope + dim_rope),
+        dtype=torch.bfloat16,
+        device=quant_k_cache.device,
+    )
+
+    # cdiv(512 + 64, 128) = 5
+    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
+    assert num_blocks_per_token == 5
+
+    assert dim_nope % group_size == 0
+
+    input_nope_q = quant_k_cache[:, :dim_nope]
+    # [:, 512:512+4*4] = [:, 512:528]
+    input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
+        torch.float32
+    )
+    # [:, 528:]
+    input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
+
+    _dequantize_k_cache_paged_kernel[(num_tokens, num_blocks_per_token)](
+        output,
+        input_nope_q,
+        input_nope_s,
+        input_rope,
+        page_table_1_flattened,
+        output.stride(0),
+        input_nope_q.stride(0),
+        input_nope_s.stride(0),
+        input_rope.stride(0),
+        NUM_NOPE_BLOCKS=num_tiles,
+        GROUP_SIZE=group_size,
+        DIM_NOPE=dim_nope,
+        DIM_ROPE=dim_rope,
+    )
+
+    return output
+
+
+@triton.jit
+def _dequantize_k_cache_paged_kernel(
+    output_ptr,
+    input_nope_q_ptr,
+    input_nope_s_ptr,
+    input_rope_ptr,
+    page_table_1_ptr,
+    output_stride_0: int,
+    input_nope_q_stride_0: int,
+    input_nope_s_stride_0: int,
+    input_rope_stride_0: int,
+    NUM_NOPE_BLOCKS: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+    DIM_NOPE: tl.constexpr,
+    DIM_ROPE: tl.constexpr,
+):
+    token_id = tl.program_id(0)
+    token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32)
+    raw_block_id = tl.program_id(1)
+
+    if raw_block_id < NUM_NOPE_BLOCKS:
+        # a. dequant nope
+        effective_block_id = raw_block_id
+
+        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+        mask = offs_q < DIM_NOPE
+        ptr_q = input_nope_q_ptr + token_id_paged * input_nope_q_stride_0 + offs_q
+        ptr_s = (
+            input_nope_s_ptr
+            + token_id_paged * input_nope_s_stride_0
+            + effective_block_id
+        )
+
+        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
+        y_s = tl.load(ptr_s)
+
+        y = (y_q * y_s).to(output_ptr.dtype.element_ty)
+
+        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
+        tl.store(dst_ptr, y, mask=mask)
+    else:
+        # b. copy rope
+        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
+
+        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+        mask = offs < DIM_ROPE
+
+        src_ptr = input_rope_ptr + token_id_paged * input_rope_stride_0 + offs
+        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
+
+        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
+        tl.store(dst_ptr, data, mask=mask)
+
+
 if __name__ == "__main__":
     raise Exception("UT is in quant_k_cache.py")
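A usage sketch for the new paged entry point, following the shapes in its docstring; the zero-filled tensors are purely illustrative, and running it requires a CUDA device with Triton and fp8 support:

import torch
from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged

# Zero-initialized stand-in for the quantized k-cache pool; the last dim must be 656
# bytes per token (512 fp8 nope values + 4 fp32 scales + 64 bf16 rope values).
kv_pool = torch.zeros(1024, 1, 656, dtype=torch.uint8, device="cuda").view(torch.float8_e4m3fn)
page_table_1 = torch.arange(32, dtype=torch.int32, device="cuda")  # flattened token slots
k_bf16 = dequantize_k_cache_paged(kv_pool, page_table_1)
assert k_bf16.shape == (32, 1, 576) and k_bf16.dtype == torch.bfloat16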

sglang/srt/layers/attention/nsa/nsa_indexer.py

@@ -119,6 +119,7 @@ class Indexer(CustomOp):
         prefix: str = "",
         quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
+        fuse_wk_and_weights_proj: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +130,7 @@ class Indexer(CustomOp):
         self.q_lora_rank = q_lora_rank
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
         if is_cuda():
             self.sm_count = deep_gemm.get_num_sms()
             self.half_device_sm_count = align(self.sm_count // 2, 8)
@@ -140,21 +142,29 @@
             quant_config=quant_config,
             prefix=add_prefix("wq_b", prefix),
         )
-        self.wk = ReplicatedLinear(
-            self.hidden_size,
-            self.head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("wk", prefix),
-        )
+        if self.fuse_wk_and_weights_proj:
+            self.fused_wk_and_weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim + self.n_heads,
+                bias=False,
+                prefix=add_prefix("fused_wk_and_weights_proj", prefix),
+            )
+        else:
+            self.wk = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("wk", prefix),
+            )
+            # NOTE: weight_proj is not quantized
+            self.weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.n_heads,
+                bias=False,
+                prefix=add_prefix("weights_proj", prefix),
+            )
         self.k_norm = V32LayerNorm(self.head_dim)
-        # NOTE: weight_proj is not quantized
-        self.weights_proj = ReplicatedLinear(
-            self.hidden_size,
-            self.n_heads,
-            bias=False,
-            prefix=add_prefix("weights_proj", prefix),
-        )
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
@@ -169,8 +179,7 @@
         self.softmax_scale = self.head_dim**-0.5
 
     @torch.compile(dynamic=True)
-    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
-        weights, _ = self.weights_proj(x)
+    def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
         weights = weights * self.n_heads**-0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         return weights
@@ -182,7 +191,7 @@
         positions: torch.Tensor,
         enable_dual_stream: bool,
     ):
-
+        weights = None
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
@@ -199,7 +208,12 @@
             )
             with torch.cuda.stream(self.alt_stream):
                 # TODO we should also put DeepGEMM half SM here?
-                key, _ = self.wk(x)
+                if self.fuse_wk_and_weights_proj:
+                    key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                        [self.head_dim, self.n_heads], dim=-1
+                    )
+                else:
+                    key, _ = self.wk(x)
                 key = self.k_norm(key)
 
                 k_rope, _ = torch.split(
@@ -217,7 +231,12 @@
                 query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
             )
 
-            key, _ = self.wk(x)
+            if self.fuse_wk_and_weights_proj:
+                key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                    [self.head_dim, self.n_heads], dim=-1
+                )
+            else:
+                key, _ = self.wk(x)
             key = self.k_norm(key)
             k_rope, _ = torch.split(
                 key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -240,7 +259,7 @@
         query = rotate_activation(query)
         key = rotate_activation(key)
 
-        return query, key
+        return query, key, weights
 
     def _get_topk_paged(
         self,
@@ -490,7 +509,9 @@ class Indexer(CustomOp):
         if metadata is None:
             return None
 
-        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+        query, key, weights = self._get_q_k_bf16(
+            q_lora, x, positions, enable_dual_stream
+        )
 
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
@@ -517,7 +538,9 @@
             index_k_scale=k_scale,
         )
 
-        weights = self._get_logits_head_gate(x, q_scale)
+        if not self.fuse_wk_and_weights_proj:
+            weights, _ = self.weights_proj(x)
+        weights = self._get_logits_head_gate(weights, q_scale)
 
         if is_cuda():
             assert forward_batch.seq_lens_cpu is not None
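For reference, when fuse_wk_and_weights_proj is enabled the indexer runs one GEMM instead of two and recovers the key and gate inputs with the same split used above. A plain-torch sketch of that split; the dimensions below are illustrative only:

import torch

head_dim, n_heads, hidden_size = 128, 64, 1024
x = torch.randn(4, hidden_size)
w_fused = torch.randn(head_dim + n_heads, hidden_size)  # stands in for fused_wk_and_weights_proj

fused_out = x @ w_fused.t()                              # [4, head_dim + n_heads]
key, weights = fused_out.split([head_dim, n_heads], dim=-1)
assert key.shape == (4, head_dim) and weights.shape == (4, n_heads)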