sglang-0.5.4-py3-none-any.whl → sglang-0.5.4.post2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +73 -14
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +30 -7
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -12
 - sglang/srt/entrypoints/engine.py +31 -20
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +94 -94
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +23 -2
 - sglang/srt/eplb/expert_distribution.py +64 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +19 -8
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +24 -1
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +35 -6
 - sglang/srt/layers/logits_processor.py +9 -20
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +78 -289
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +35 -10
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +13 -84
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +130 -46
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +29 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +185 -144
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +165 -78
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +253 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +55 -14
 - sglang/srt/model_executor/model_runner.py +77 -170
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +296 -78
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +29 -197
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +23 -2
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +35 -5
 - sglang/srt/models/qwen3_moe.py +18 -12
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +459 -199
 - sglang/srt/single_batch_overlap.py +2 -4
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +142 -74
 - sglang/srt/utils/hf_transformers_utils.py +38 -12
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
 - sglang/srt/models/vila.py +0 -306
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
Selected hunks from the diff:

sglang/srt/layers/attention/nsa/nsa_indexer.py

@@ -119,6 +119,7 @@ class Indexer(CustomOp):
         prefix: str = "",
         quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
+        fuse_wk_and_weights_proj: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +130,7 @@ class Indexer(CustomOp):
         self.q_lora_rank = q_lora_rank
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
         if is_cuda():
             self.sm_count = deep_gemm.get_num_sms()
             self.half_device_sm_count = align(self.sm_count // 2, 8)
@@ -140,21 +142,29 @@ class Indexer(CustomOp):
             quant_config=quant_config,
             prefix=add_prefix("wq_b", prefix),
         )
-        self.wk = ReplicatedLinear(
-            self.hidden_size,
-            self.head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("wk", prefix),
-        )
+        if self.fuse_wk_and_weights_proj:
+            self.fused_wk_and_weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim + self.n_heads,
+                bias=False,
+                prefix=add_prefix("fused_wk_and_weights_proj", prefix),
+            )
+        else:
+            self.wk = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("wk", prefix),
+            )
+            # NOTE: weight_proj is not quantized
+            self.weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.n_heads,
+                bias=False,
+                prefix=add_prefix("weights_proj", prefix),
+            )
         self.k_norm = V32LayerNorm(self.head_dim)
-        # NOTE: weight_proj is not quantized
-        self.weights_proj = ReplicatedLinear(
-            self.hidden_size,
-            self.n_heads,
-            bias=False,
-            prefix=add_prefix("weights_proj", prefix),
-        )
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
@@ -169,8 +179,7 @@ class Indexer(CustomOp):
         self.softmax_scale = self.head_dim**-0.5

     @torch.compile(dynamic=True)
-    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
-        weights, _ = self.weights_proj(x)
+    def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
         weights = weights * self.n_heads**-0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         return weights
@@ -182,7 +191,7 @@ class Indexer(CustomOp):
         positions: torch.Tensor,
         enable_dual_stream: bool,
     ):
-
+        weights = None
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
@@ -199,7 +208,12 @@ class Indexer(CustomOp):
                 )
             with torch.cuda.stream(self.alt_stream):
                 # TODO we should also put DeepGEMM half SM here?
-                key, _ = self.wk(x)
+                if self.fuse_wk_and_weights_proj:
+                    key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                        [self.head_dim, self.n_heads], dim=-1
+                    )
+                else:
+                    key, _ = self.wk(x)
                 key = self.k_norm(key)

                 k_rope, _ = torch.split(
@@ -217,7 +231,12 @@ class Indexer(CustomOp):
             query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
         )

-        key, _ = self.wk(x)
+        if self.fuse_wk_and_weights_proj:
+            key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                [self.head_dim, self.n_heads], dim=-1
+            )
+        else:
+            key, _ = self.wk(x)
         key = self.k_norm(key)
         k_rope, _ = torch.split(
             key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -240,7 +259,7 @@ class Indexer(CustomOp):
             query = rotate_activation(query)
             key = rotate_activation(key)

-        return query, key
+        return query, key, weights

     def _get_topk_paged(
         self,
@@ -490,7 +509,9 @@ class Indexer(CustomOp):
         if metadata is None:
             return None

-        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+        query, key, weights = self._get_q_k_bf16(
+            q_lora, x, positions, enable_dual_stream
+        )

         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
@@ -517,7 +538,9 @@ class Indexer(CustomOp):
             index_k_scale=k_scale,
         )

-        weights = self._get_logits_head_gate(x, q_scale)
+        if not self.fuse_wk_and_weights_proj:
+            weights, _ = self.weights_proj(x)
+        weights = self._get_logits_head_gate(weights, q_scale)

         if is_cuda():
             assert forward_batch.seq_lens_cpu is not None
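For reference, the change above folds the separate `wk` and `weights_proj` GEMMs into one projection whose output columns are split back into the key and the per-head gate logits. A minimal sketch of the pattern, with plain `torch.nn.Linear` standing in for sglang's `ReplicatedLinear` and purely illustrative sizes:

```python
import torch

hidden_size, head_dim, n_heads = 1024, 128, 64

# One fused GEMM; output columns are [key (head_dim) | gate logits (n_heads)].
fused = torch.nn.Linear(hidden_size, head_dim + n_heads, bias=False)

x = torch.randn(8, hidden_size)
key, weights = fused(x).split([head_dim, n_heads], dim=-1)
assert key.shape == (8, head_dim) and weights.shape == (8, n_heads)

# Equivalent unfused form: two GEMMs reading the same activations.
wk = torch.nn.Linear(hidden_size, head_dim, bias=False)
weights_proj = torch.nn.Linear(hidden_size, n_heads, bias=False)
key2, weights2 = wk(x), weights_proj(x)
```

Fusing trades a second kernel launch and a second read of `x` for one larger GEMM, which is why `_get_q_k_bf16` now also returns `weights`.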
sglang/srt/layers/attention/nsa/quant_k_cache.py

@@ -206,6 +206,8 @@ def _quantize_k_cache_fast_kernel(
 
 
 if __name__ == "__main__":
+    import dequant_k_cache
+
     for num_blocks, block_size in [
         (1, 1),
         (10, 64),
@@ -217,21 +219,9 @@ if __name__ == "__main__":
             dtype=torch.bfloat16,
             device="cuda",
         )
-        # temp debug
-        # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)

         ref_quant = _quantize_k_cache_slow(input_k_cache)
         actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
-        # print(f"{input_k_cache=}")
-        # print(f"{ref_quant=}")
-        # print(f"{actual_quant=}")
-        # print(f"{ref_quant == actual_quant=}")
-        # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
-        # print(f"{ref_quant.view(torch.bfloat16)=}")
-        # print(f"{actual_quant.view(torch.bfloat16)=}")
-        # assert torch.all(ref_quant == actual_quant)
-
-        import dequant_k_cache

         ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
         ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
@@ -252,4 +242,46 @@ if __name__ == "__main__":
             ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
         )

+        # test dequant_k_cache_paged
+        page_table_1 = torch.arange(
+            num_blocks * block_size, dtype=torch.int32, device="cuda"
+        )
+        actual_dequant_paged = dequant_k_cache.dequantize_k_cache_paged(
+            actual_quant, page_table_1
+        ).reshape(actual_actual_dequant.shape)
+        print(f"{torch.mean(actual_actual_dequant - actual_dequant_paged)=}")
+        torch.testing.assert_close(
+            ref_ref_dequant, actual_dequant_paged, atol=0.2, rtol=0.2
+        )
+
     print("Passed")
+    print("Do benchmark...")
+
+    for num_blocks, block_size in [
+        (1, 64),
+        (64, 64),
+        (128, 64),
+        (256, 64),
+        (512, 64),
+        (1024, 64),
+        (2048, 64),
+    ]:
+        dim_nope_and_rope = 512 + 64
+
+        input_k_cache = torch.randn(
+            (num_blocks, block_size, 1, dim_nope_and_rope),
+            dtype=torch.bfloat16,
+            device="cuda",
+        )
+
+        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
+
+        page_table_1 = torch.arange(
+            num_blocks * block_size, dtype=torch.int32, device="cuda"
+        )
+
+        def run_ans():
+            return dequant_k_cache.dequantize_k_cache_paged(actual_quant, page_table_1)
+
+        ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000  # type: ignore
+        print(f"seq_kv: {num_blocks * block_size}, time: {ans_time * 1e6: 4.0f} us")
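The new test and benchmark above exercise `dequantize_k_cache_paged`, which gathers KV-pool entries through a page-size-1 page table while dequantizing. With the identity page table used in the test, the gather semantics reduce to plain indexing; a pure-PyTorch sketch (where `gather_paged` is a stand-in for the kernel's selection step, not sglang's API):

```python
import torch

def gather_paged(flat_kv: torch.Tensor, page_table_1: torch.Tensor) -> torch.Tensor:
    # flat_kv: (num_tokens, 1, dim) entries in KV-pool order; page_table_1 maps
    # each logical token slot to a pool slot (page size 1).
    return flat_kv[page_table_1.long()]

num_tokens, dim = 640, 576  # 576 = 512 + 64, as in the test above
flat_kv = torch.randn(num_tokens, 1, dim)
page_table_1 = torch.arange(num_tokens, dtype=torch.int32)

# With an identity table, the paged gather must reproduce the flat layout.
assert torch.equal(gather_paged(flat_kv, page_table_1), flat_kv)
```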
sglang/srt/layers/attention/nsa/transform_index.py

@@ -103,7 +103,7 @@ def transform_index_page_table_decode_ref(
         result = torch.empty_like(topk_indices, dtype=torch.int32)
     assert result.shape == topk_indices.shape
     torch.gather(
-        page_table,
+        page_table.to(result.dtype),
         dim=1,
         index=topk_indices.clamp(min=0),
         out=result,
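The one-line fix above is needed because `torch.gather` requires the `out=` tensor to have the same dtype as the input: the request-to-token page table is typically `int64`, while `result` is allocated as `int32`. A minimal reproduction:

```python
import torch

page_table = torch.arange(12, dtype=torch.int64).reshape(3, 4)
topk_indices = torch.tensor([[0, 2], [1, 3], [0, 0]])
result = torch.empty_like(topk_indices, dtype=torch.int32)

try:
    # Fails: the out tensor's dtype must match the input's dtype.
    torch.gather(page_table, dim=1, index=topk_indices, out=result)
except RuntimeError as e:
    print(e)

# The patched form casts the input to the output dtype first.
torch.gather(page_table.to(result.dtype), dim=1, index=topk_indices, out=result)
```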
sglang/srt/layers/attention/nsa_backend.py

@@ -1,12 +1,14 @@
 from __future__ import annotations

 from dataclasses import dataclass
+from enum import IntEnum, auto
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias

 import torch

 from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged
 from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.layers.attention.nsa.transform_index import (
@@ -98,11 +100,27 @@ class NSAMetadata:
     nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend

     flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
+    # The sum of sequence lengths for key, prefill only
+    seq_lens_sum: Optional[int] = None
+    # The flattened 1D page table with shape (seq_lens_sum,), prefill only
+    # this table is always with page_size = 1
+    page_table_1_flattened: Optional[torch.Tensor] = None
+    # The offset of topk indices in ragged kv, prefill only
+    # shape: (seq_lens_sum,)
+    topk_indices_offset: Optional[torch.Tensor] = None
+
+
+class TopkTransformMethod(IntEnum):
+    # Transform topk indices to indices to the page table (page_size = 1)
+    PAGED = auto()
+    # Transform topk indices to indices to ragged kv (non-paged)
+    RAGGED = auto()


 @dataclass(frozen=True)
 class NSAIndexerMetadata(BaseIndexerMetadata):
     attn_metadata: NSAMetadata
+    topk_transform_method: TopkTransformMethod

     def get_seqlens_int32(self) -> torch.Tensor:
         return self.attn_metadata.cache_seqlens_int32
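Operationally, the two `TopkTransformMethod` variants differ in what the transformed indices address: PAGED maps each top-k position through the per-request page table into KV-pool slots, while RAGGED maps it to a position in the requests' concatenated (ragged) KV stream. A toy illustration with made-up numbers:

```python
import torch

# Two requests with KV lengths 3 and 2, page size 1; -1 pads unused slots.
page_table = torch.tensor([[7, 4, 9, -1], [2, 5, -1, -1]])  # pool slot per token
seq_lens = torch.tensor([3, 2])

# Top-1 token chosen per request, relative to that request's own KV.
topk = torch.tensor([[2], [0]])

# PAGED: route through the page table -> global KV-pool slots.
paged = torch.gather(page_table, 1, topk)      # tensor([[9], [2]])

# RAGGED: shift by each request's start in the concatenated KV stream.
starts = torch.cumsum(seq_lens, 0) - seq_lens  # tensor([0, 3])
ragged = topk + starts.unsqueeze(1)            # tensor([[2], [3]])
print(paged.tolist(), ragged.tolist())
```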
sglang/srt/layers/attention/nsa_backend.py (continued)

@@ -118,23 +136,36 @@ class NSAIndexerMetadata(BaseIndexerMetadata):
         logits: torch.Tensor,
         topk: int,
     ) -> torch.Tensor:
-        from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+        from sgl_kernel import (
+            fast_topk_transform_fused,
+            fast_topk_transform_ragged_fused,
+            fast_topk_v2,
+        )

         if not NSA_FUSE_TOPK:
             return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
-        # NOTE(dark): if fused, we return a transformed page table directly
-        return fast_topk_transform_fused(
-            score=logits,
-            lengths=self.get_seqlens_expanded(),
-            page_table_size_1=self.attn_metadata.page_table_1,
-            cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
-            topk=topk,
-        )
+        elif self.topk_transform_method == TopkTransformMethod.PAGED:
+            # NOTE(dark): if fused, we return a transformed page table directly
+            return fast_topk_transform_fused(
+                score=logits,
+                lengths=self.get_seqlens_expanded(),
+                page_table_size_1=self.attn_metadata.page_table_1,
+                cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
+                topk=topk,
+            )
+        elif self.topk_transform_method == TopkTransformMethod.RAGGED:
+            return fast_topk_transform_ragged_fused(
+                score=logits,
+                lengths=self.get_seqlens_expanded(),
+                topk_indices_offset=self.attn_metadata.topk_indices_offset,
+                topk=topk,
+            )
+        else:
+            assert False, f"Unsupported {self.topk_transform_method = }"


 def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
-    assert seqlens.dtype == torch.int32
+    assert seqlens.dtype == torch.int32
     return torch.nn.functional.pad(
         torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
     )
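For context, `compute_cu_seqlens` (unchanged in substance above) is the standard exclusive-prefix-sum construction: padding the cumulative sum with one leading zero yields the `cu_seqlens` array that varlen attention kernels index with:

```python
import torch

seqlens = torch.tensor([3, 1, 4], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens)  # tensor([0, 3, 4, 8], dtype=torch.int32)
# Request i owns the half-open range [cu_seqlens[i], cu_seqlens[i + 1]).
```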
         @@ -181,6 +212,7 @@ class NativeSparseAttnBackend(AttentionBackend): 
     | 
|
| 
       181 
212 
     | 
    
         
             
                    global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
         
     | 
| 
       182 
213 
     | 
    
         
             
                    NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
         
     | 
| 
       183 
214 
     | 
    
         
             
                    NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
         
     | 
| 
      
 215 
     | 
    
         
            +
                    self.enable_auto_select_prefill_impl = NSA_PREFILL_IMPL == "flashmla_auto"
         
     | 
| 
       184 
216 
     | 
    
         | 
| 
       185 
217 
     | 
    
         
             
                    self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
         
     | 
| 
       186 
218 
     | 
    
         | 
| 
         @@ -231,10 +263,16 @@ class NativeSparseAttnBackend(AttentionBackend): 
     | 
|
| 
       231 
263 
     | 
    
         
             
                    cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
         
     | 
| 
       232 
264 
     | 
    
         
             
                    assert forward_batch.seq_lens_cpu is not None
         
     | 
| 
       233 
265 
     | 
    
         
             
                    max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
         
     | 
| 
      
 266 
     | 
    
         
            +
                    # [b, max_seqlen_k]
         
     | 
| 
       234 
267 
     | 
    
         
             
                    page_table = forward_batch.req_to_token_pool.req_to_token[
         
     | 
| 
       235 
268 
     | 
    
         
             
                        forward_batch.req_pool_indices, :max_seqlen_k
         
     | 
| 
       236 
269 
     | 
    
         
             
                    ]
         
     | 
| 
       237 
270 
     | 
    
         | 
| 
      
 271 
     | 
    
         
            +
                    page_table_1_flattened = None
         
     | 
| 
      
 272 
     | 
    
         
            +
                    topk_indices_offset = None
         
     | 
| 
      
 273 
     | 
    
         
            +
                    self.set_nsa_prefill_impl(forward_batch)
         
     | 
| 
      
 274 
     | 
    
         
            +
                    topk_transform_method = self.get_topk_transform_method()
         
     | 
| 
      
 275 
     | 
    
         
            +
             
     | 
| 
       238 
276 
     | 
    
         
             
                    if forward_batch.forward_mode.is_decode_or_idle():
         
     | 
| 
       239 
277 
     | 
    
         
             
                        extend_seq_lens_cpu = [1] * batch_size
         
     | 
| 
       240 
278 
     | 
    
         
             
                        max_seqlen_q = 1
         
     | 
| 
         @@ -295,6 +333,7 @@ class NativeSparseAttnBackend(AttentionBackend): 
     | 
|
| 
       295 
333 
     | 
    
         
             
                        else:
         
     | 
| 
       296 
334 
     | 
    
         
             
                            max_seqlen_q = max_seqlen_k
         
     | 
| 
       297 
335 
     | 
    
         
             
                            cu_seqlens_q = cu_seqlens_k
         
     | 
| 
      
 336 
     | 
    
         
            +
             
     | 
| 
       298 
337 
     | 
    
         
             
                        seqlens_expanded = torch.cat(
         
     | 
| 
       299 
338 
     | 
    
         
             
                            [
         
     | 
| 
       300 
339 
     | 
    
         
             
                                torch.arange(
         
     | 
| 
         @@ -310,6 +349,24 @@ class NativeSparseAttnBackend(AttentionBackend): 
     | 
|
| 
       310 
349 
     | 
    
         
             
                                )
         
     | 
| 
       311 
350 
     | 
    
         
             
                            ]
         
     | 
| 
       312 
351 
     | 
    
         
             
                        )
         
     | 
| 
      
 352 
     | 
    
         
            +
             
     | 
| 
      
 353 
     | 
    
         
            +
                        if topk_transform_method == TopkTransformMethod.RAGGED:
         
     | 
| 
      
 354 
     | 
    
         
            +
                            page_table_1_flattened = torch.cat(
         
     | 
| 
      
 355 
     | 
    
         
            +
                                [
         
     | 
| 
      
 356 
     | 
    
         
            +
                                    page_table[i, :kv_len]
         
     | 
| 
      
 357 
     | 
    
         
            +
                                    for i, kv_len in enumerate(
         
     | 
| 
      
 358 
     | 
    
         
            +
                                        forward_batch.seq_lens_cpu.tolist(),
         
     | 
| 
      
 359 
     | 
    
         
            +
                                    )
         
     | 
| 
      
 360 
     | 
    
         
            +
                                ]
         
     | 
| 
      
 361 
     | 
    
         
            +
                            )
         
     | 
| 
      
 362 
     | 
    
         
            +
                            assert (
         
     | 
| 
      
 363 
     | 
    
         
            +
                                page_table_1_flattened.shape[0] == forward_batch.seq_lens_sum
         
     | 
| 
      
 364 
     | 
    
         
            +
                            ), f"{page_table_1_flattened.shape[0] = } must be the same as {forward_batch.seq_lens_sum = }"
         
     | 
| 
      
 365 
     | 
    
         
            +
             
     | 
| 
      
 366 
     | 
    
         
            +
                            topk_indices_offset = torch.repeat_interleave(
         
     | 
| 
      
 367 
     | 
    
         
            +
                                cu_seqlens_k[:-1],
         
     | 
| 
      
 368 
     | 
    
         
            +
                                forward_batch.extend_seq_lens,
         
     | 
| 
      
 369 
     | 
    
         
            +
                            )
         
     | 
| 
       313 
370 
     | 
    
         
             
                    else:
         
     | 
| 
       314 
371 
     | 
    
         
             
                        assert False, f"Unsupported {forward_batch.forward_mode = }"
         
     | 
| 
       315 
372 
     | 
    
         | 
| 
         @@ -328,7 +385,9 @@ class NativeSparseAttnBackend(AttentionBackend): 
     | 
|
| 
       328 
385 
     | 
    
         
             
                        max_seq_len_k=max_seqlen_k,
         
     | 
| 
       329 
386 
     | 
    
         
             
                        cu_seqlens_q=cu_seqlens_q,
         
     | 
| 
       330 
387 
     | 
    
         
             
                        cu_seqlens_k=cu_seqlens_k,
         
     | 
| 
      
 388 
     | 
    
         
            +
                        seq_lens_sum=forward_batch.seq_lens_sum,
         
     | 
| 
       331 
389 
     | 
    
         
             
                        page_table_1=page_table,
         
     | 
| 
      
 390 
     | 
    
         
            +
                        page_table_1_flattened=page_table_1_flattened,
         
     | 
| 
       332 
391 
     | 
    
         
             
                        flashmla_metadata=(
         
     | 
| 
       333 
392 
     | 
    
         
             
                            self._compute_flashmla_metadata(
         
     | 
| 
       334 
393 
     | 
    
         
             
                                cache_seqlens=nsa_cache_seqlens_int32,
         
     | 
| 
         @@ -344,6 +403,7 @@ class NativeSparseAttnBackend(AttentionBackend): 
     | 
|
| 
       344 
403 
     | 
    
         
             
                        nsa_extend_seq_lens_list=extend_seq_lens_cpu,
         
     | 
| 
       345 
404 
     | 
    
         
             
                        real_page_table=self._transform_table_1_to_real(page_table),
         
     | 
| 
       346 
405 
     | 
    
         
             
                        nsa_max_seqlen_q=1,
         
     | 
| 
      
 406 
     | 
    
         
            +
                        topk_indices_offset=topk_indices_offset,
         
     | 
| 
       347 
407 
     | 
    
         
             
                    )
         
     | 
| 
       348 
408 
     | 
    
         | 
| 
       349 
409 
     | 
    
         
             
                    self.forward_metadata = metadata
         
     | 
| 
         @@ -396,6 +456,8 @@ class NativeSparseAttnBackend(AttentionBackend): 
     | 
|
| 
       396 
456 
     | 
    
         
             
                    forward_mode: ForwardMode,
         
     | 
| 
       397 
457 
     | 
    
         
             
                    spec_info: Optional[SpecInput],
         
     | 
| 
       398 
458 
     | 
    
         
             
                ):
         
     | 
| 
      
 459 
     | 
    
         
            +
                    self.set_nsa_prefill_impl(forward_batch=None)
         
     | 
| 
      
 460 
     | 
    
         
            +
             
     | 
| 
       399 
461 
     | 
    
         
             
                    """Initialize forward metadata for capturing CUDA graph."""
         
     | 
| 
       400 
462 
     | 
    
         
             
                    if forward_mode.is_decode_or_idle():
         
     | 
| 
       401 
463 
     | 
    
         
             
                        # Normal Decode
         
     | 
| 
@@ -586,6 +648,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         """Initialize forward metadata for replaying CUDA graph."""
         assert seq_lens_cpu is not None

+        self.set_nsa_prefill_impl(forward_batch=None)
+
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
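Both CUDA-graph hooks above call `set_nsa_prefill_impl` with `forward_batch=None`, so the extend-time heuristic defined later in this diff cannot fire during capture or replay. Assuming `enable_auto_select_prefill_impl` is set, the effective choice reduces to a sketch like:

```python
def select_prefill_impl_for_cuda_graph(nsa_kv_cache_store_fp8: bool) -> str:
    # With forward_batch=None the Blackwell extend branch is skipped, so an
    # fp8 KV cache falls back to flashmla_kv; bf16 keeps flashmla_sparse.
    return "flashmla_kv" if nsa_kv_cache_store_fp8 else "flashmla_sparse"
```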
@@ -780,17 +844,31 @@ class NativeSparseAttnBackend(AttentionBackend):
             q_rope = q_all[:, :, layer.v_head_dim :]

         # NOTE(dark): here, we use page size = 1
-
+        topk_transform_method = self.get_topk_transform_method()
         if NSA_FUSE_TOPK:
             page_table_1 = topk_indices
         else:
-
-
-
-
-
-
-
+            if topk_transform_method == TopkTransformMethod.RAGGED:
+                topk_indices_offset = metadata.topk_indices_offset
+                assert topk_indices_offset is not None
+                mask = topk_indices != -1
+                topk_indices_offset = (
+                    topk_indices_offset.unsqueeze(1)
+                    if topk_indices_offset.ndim == 1
+                    else topk_indices_offset
+                )
+                topk_indices = torch.where(
+                    mask, topk_indices + topk_indices_offset, topk_indices
+                )
+            elif topk_transform_method == TopkTransformMethod.PAGED:
+                assert metadata.nsa_extend_seq_lens_list is not None
+                page_table_1 = transform_index_page_table_prefill(
+                    page_table=metadata.page_table_1,
+                    topk_indices=topk_indices,
+                    extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+                    page_size=1,
+                )
+
         if NSA_PREFILL_IMPL == "tilelang":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
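The RAGGED branch above shifts each request's local top-k indices into the flattened KV index space while leaving the `-1` padding untouched. A toy run of the same `torch.where` pattern (values invented for illustration):

```python
import torch

topk_indices = torch.tensor([[0, 2, -1], [1, 3, -1]])  # per-request local indices
topk_indices_offset = torch.tensor([0, 100])           # where each request's tokens
                                                       # start in the flattened cache
mask = topk_indices != -1
shifted = torch.where(
    mask, topk_indices + topk_indices_offset.unsqueeze(1), topk_indices
)
# shifted -> tensor([[  0,   2,  -1],
#                    [101, 103,  -1]])
```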
@@ -804,6 +882,22 @@ class NativeSparseAttnBackend(AttentionBackend):
         elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+
+            # NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here,
+            # because the flashmla_sparse kernel doesn't support fp8 compute
+            if topk_transform_method == TopkTransformMethod.RAGGED:
+                if any(forward_batch.extend_prefix_lens_cpu):
+                    page_table_1_flattened = (
+                        self.forward_metadata.page_table_1_flattened
+                    )
+                    assert page_table_1_flattened is not None
+                    kv_cache = dequantize_k_cache_paged(
+                        kv_cache, page_table_1_flattened
+                    )
+                else:
+                    kv_cache = torch.cat([k, k_rope], dim=-1)
+                page_table_1 = topk_indices
+
             return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
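On the RAGGED path above, an fp8 KV cache is widened before the bf16-only flashmla_sparse kernel runs; when some request continues a cached prefix, only the rows named by the flattened page table need dequantizing. An illustrative-only sketch of that gather-and-widen step (the real `dequantize_k_cache_paged` kernel also applies the per-block fp8 scales stored with the cache):

```python
import torch

def dequantize_k_cache_paged_sketch(
    kv_cache_fp8: torch.Tensor,            # [num_tokens, head_dim], float8_e4m3fn
    page_table_1_flattened: torch.Tensor,  # ragged token indices, page size 1
) -> torch.Tensor:
    # Gather only the referenced cache rows, then widen them for the kernel.
    return kv_cache_fp8[page_table_1_flattened.long()].to(torch.bfloat16)
```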
@@ -1004,7 +1098,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         page_table_1: torch.Tensor,
         sm_scale: float,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_sparse_fwd
+        from sgl_kernel.flash_mla import flash_mla_sparse_fwd

         o, _, _ = flash_mla_sparse_fwd(
             q=q_all,
@@ -1025,7 +1119,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         metadata: NSAMetadata,
         page_table_1,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_with_kvcache
+        from sgl_kernel.flash_mla import flash_mla_with_kvcache

         cache_seqlens = metadata.nsa_cache_seqlens_int32
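This hunk and its neighbors swap `from flash_mla import ...` for `from sgl_kernel.flash_mla import ...`: the FlashMLA kernels are now shipped inside the sgl_kernel wheel rather than a standalone package. If code had to run against both layouts, a hedged shim (an assumption, not part of this diff) could look like:

```python
try:
    # layout used from sglang 0.5.4.post2 onward
    from sgl_kernel.flash_mla import flash_mla_sparse_fwd, flash_mla_with_kvcache
except ImportError:
    # older standalone flash_mla package (assumed fallback)
    from flash_mla import flash_mla_sparse_fwd, flash_mla_with_kvcache
```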
@@ -1121,13 +1215,53 @@ class NativeSparseAttnBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 1

+    def set_nsa_prefill_impl(self, forward_batch: Optional[ForwardBatch] = None) -> str:
+        from sglang.srt.utils import is_blackwell
+
+        global NSA_PREFILL_IMPL
+        if self.enable_auto_select_prefill_impl:
+            if self.nsa_kv_cache_store_fp8:
+                if (
+                    is_blackwell()
+                    and forward_batch is not None
+                    and forward_batch.forward_mode == ForwardMode.EXTEND
+                ):
+                    total_kv_tokens = forward_batch.seq_lens_sum
+                    total_q_tokens = forward_batch.extend_num_tokens
+                    # Heuristic based on benchmarking flashmla_kv vs flashmla_sparse + dequantize_k_cache_paged
+                    if total_kv_tokens < total_q_tokens * 512:
+                        NSA_PREFILL_IMPL = "flashmla_sparse"
+                        return
+                NSA_PREFILL_IMPL = "flashmla_kv"
+            else:
+                # bf16 kv cache
+                NSA_PREFILL_IMPL = "flashmla_sparse"
+
+    def get_topk_transform_method(self) -> TopkTransformMethod:
+        """
+        NSA_FUSE_TOPK controls whether to fuse the topk transform into the topk kernel.
+        This method is used to select the topk transform method, which can be fused or unfused.
+        """
+        if (
+            # disable for MTP
+            self.nsa_kv_cache_store_fp8
+            and NSA_PREFILL_IMPL == "flashmla_sparse"
+        ):
+            topk_transform_method = TopkTransformMethod.RAGGED
+        else:
+            topk_transform_method = TopkTransformMethod.PAGED
+        return topk_transform_method
+
     def get_indexer_metadata(
         self, layer_id: int, forward_batch: ForwardBatch
     ) -> NSAIndexerMetadata:
-        return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
+        return NSAIndexerMetadata(
+            attn_metadata=self.forward_metadata,
+            topk_transform_method=self.get_topk_transform_method(),
+        )

     def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
-        from flash_mla import get_mla_metadata
+        from sgl_kernel.flash_mla import get_mla_metadata

         flashmla_metadata, num_splits = get_mla_metadata(
             cache_seqlens=cache_seqlens,
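The `total_kv_tokens < total_q_tokens * 512` heuristic above trades the cost of dequantizing the whole fp8 context (flashmla_sparse path) against the fp8-native flashmla_kv kernel. A worked toy decision (numbers invented):

```python
total_q_tokens = 64        # tokens extended this step
total_kv_tokens = 16_384   # context that would need paged dequantization

if total_kv_tokens < total_q_tokens * 512:   # 16_384 < 32_768 -> True
    impl = "flashmla_sparse"                 # dequant cost is acceptable
else:
    impl = "flashmla_kv"                     # too much KV to dequantize
assert impl == "flashmla_sparse"
```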
@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        if model_runner.hybrid_gdn_config is not None:
+        if (
+            model_runner.hybrid_gdn_config is not None
+            or model_runner.kimi_linear_config is not None
+        ):
             # For hybrid linear models, layer_id = 0 may not be full attention
             self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
         else:
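Per the guard above, hybrid linear models (GDN-style, and now also Kimi-Linear) may not have full attention at layer 0, so the value head dim is read from the KV pool instead. A paraphrased sketch; the `else` branch attribute is a guess, since it sits outside this hunk:

```python
def pick_v_head_dim(model_runner) -> int:
    if (
        model_runner.hybrid_gdn_config is not None
        or model_runner.kimi_linear_config is not None
    ):
        # layer_id 0 may be a linear-attention layer; ask the KV pool instead
        return model_runner.token_to_kv_pool.get_v_head_dim()
    return model_runner.model_config.v_head_dim  # hypothetical else branch
```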
@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-            if (
-
-
-            ):
+            if any(
+                forward_batch.extend_prefix_lens_cpu
+            ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
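The rewritten guard treats "some request continues a cached prefix" and draft-extend verification uniformly when deciding whether per-request `cu_seqlens_q` must be built. Toy values:

```python
extend_prefix_lens_cpu = [0, 0, 5]  # third request reuses a 5-token prefix
is_draft_extend = False             # assumed for this example

if any(extend_prefix_lens_cpu) or is_draft_extend:
    print("build ragged cu_seqlens_q from extend_seq_lens")
```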
@@ -529,6 +528,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             layer, cache_loc, k, v, layer.k_scale, layer.v_scale
         )

+        if self.data_type == torch.float8_e4m3fn:
+            q = q.to(torch.float8_e4m3fn)
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
@@ -567,6 +568,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
             sinks=attention_sink,
+            out_dtype=self.q_data_type,  # model_runner.dtype
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
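Together with the cast added in the previous hunk, this pins down the fp8 I/O contract: the query is narrowed to e4m3 to match an fp8 KV cache, while `out_dtype` keeps the kernel's output in the model's compute dtype. A toy sketch (tensors and dtypes assumed):

```python
import torch

q = torch.randn(2, 8, 64, dtype=torch.bfloat16)
kv_cache_dtype = torch.float8_e4m3fn   # assumed cache dtype

if kv_cache_dtype == torch.float8_e4m3fn:
    q = q.to(torch.float8_e4m3fn)      # inputs narrowed to match the cache
# the kernel is then asked for out_dtype=torch.bfloat16, so downstream code
# keeps receiving activations in the model's compute dtype
```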
@@ -586,6 +588,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         forward_batch.token_to_kv_pool.set_kv_buffer(
             layer, cache_loc, k, v, layer.k_scale, layer.v_scale
         )
+
+        if self.data_type == torch.float8_e4m3fn:
+            q = q.to(torch.float8_e4m3fn)
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -625,6 +630,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
             sinks=attention_sink,
+            out_dtype=self.q_data_type,  # model_runner.dtype
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -423,14 +423,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )

-        # Record the true maximum sequence length for this capture batch so that
-        # the kernel launch path (which requires an int not a tensor) can reuse
-        # it safely during both capture and replay.
-        max_seq_len_val = int(seq_lens.max().item())
-
         metadata = TRTLLMMLADecodeMetadata(
             block_kv_indices,
-            max_seq_len_val,
+            self.max_context_len,
         )
         if forward_mode.is_draft_extend(include_v2=True):
             num_tokens_per_bs = num_tokens // bs
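The likely rationale (my reading, not stated in the diff): the launch path needs a Python int, but an int computed at capture time is baked into the CUDA graph and can be wrong for a replayed batch, whereas the model's context limit is a bound that always holds:

```python
max_context_len = 8192             # illustrative model context limit
replayed_seq_lens = [37, 1024, 511]

# int(seq_lens.max()) taken at capture time could be smaller than a later
# batch's true maximum; the fixed upper bound can never go stale.
assert max(replayed_seq_lens) <= max_context_len
```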
@@ -509,13 +504,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )

-        # Update stored max_seq_len so subsequent kernel calls use the correct value
-        # Prefer CPU tensor to avoid GPU synchronization when available.
-        if seq_lens_cpu is not None:
-            metadata.max_seq_len = int(seq_lens_cpu.max().item())
-        else:
-            metadata.max_seq_len = int(seq_lens.max().item())
-
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
@@ -956,8 +944,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                     metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
                 )
             else:
-
-
+                # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
+                # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
+                # This will ensure queries align with kvs correctly when calling
+                # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
+                seq_lens = (
+                    forward_batch.seq_lens
+                    - metadata.seq_lens_q
+                    + metadata.max_seq_len_q
+                ).to(torch.int32)
+                max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
             # Check if we're in CUDA graph mode (buffers are pre-allocated)
             if self.padded_q_buffer is not None:
                 # Use pre-allocated buffer for CUDA graph compatibility
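The added arithmetic rebuilds `seq_lens` so padded draft queries line up with the KV cache: subtract the accepted query length, then add the padded maximum. Worked toy numbers:

```python
import torch

seq_lens_batch = torch.tensor([103])  # prev_context (100) + 3 verified tokens
seq_lens_q = torch.tensor([3])        # accepted query tokens for this request
max_seq_len_q = 4                     # padded draft length

seq_lens = (seq_lens_batch - seq_lens_q + max_seq_len_q).to(torch.int32)
# -> tensor([104], dtype=torch.int32): prev_context + max_draft_tokens
```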