sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/layers/attention/nsa/quant_k_cache.py

@@ -206,6 +206,8 @@ def _quantize_k_cache_fast_kernel(
 
 
 if __name__ == "__main__":
+    import dequant_k_cache
+
     for num_blocks, block_size in [
         (1, 1),
         (10, 64),
@@ -217,21 +219,9 @@ if __name__ == "__main__":
             dtype=torch.bfloat16,
             device="cuda",
         )
-        # temp debug
-        # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
 
         ref_quant = _quantize_k_cache_slow(input_k_cache)
         actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
-        # print(f"{input_k_cache=}")
-        # print(f"{ref_quant=}")
-        # print(f"{actual_quant=}")
-        # print(f"{ref_quant == actual_quant=}")
-        # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
-        # print(f"{ref_quant.view(torch.bfloat16)=}")
-        # print(f"{actual_quant.view(torch.bfloat16)=}")
-        # assert torch.all(ref_quant == actual_quant)
-
-        import dequant_k_cache
 
         ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
         ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
@@ -252,4 +242,46 @@ if __name__ == "__main__":
             ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
         )
 
+        # test dequant_k_cache_paged
+        page_table_1 = torch.arange(
+            num_blocks * block_size, dtype=torch.int32, device="cuda"
+        )
+        actual_dequant_paged = dequant_k_cache.dequantize_k_cache_paged(
+            actual_quant, page_table_1
+        ).reshape(actual_actual_dequant.shape)
+        print(f"{torch.mean(actual_actual_dequant - actual_dequant_paged)=}")
+        torch.testing.assert_close(
+            ref_ref_dequant, actual_dequant_paged, atol=0.2, rtol=0.2
+        )
+
     print("Passed")
+    print("Do benchmark...")
+
+    for num_blocks, block_size in [
+        (1, 64),
+        (64, 64),
+        (128, 64),
+        (256, 64),
+        (512, 64),
+        (1024, 64),
+        (2048, 64),
+    ]:
+        dim_nope_and_rope = 512 + 64
+
+        input_k_cache = torch.randn(
+            (num_blocks, block_size, 1, dim_nope_and_rope),
+            dtype=torch.bfloat16,
+            device="cuda",
+        )
+
+        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
+
+        page_table_1 = torch.arange(
+            num_blocks * block_size, dtype=torch.int32, device="cuda"
+        )
+
+        def run_ans():
+            return dequant_k_cache.dequantize_k_cache_paged(actual_quant, page_table_1)
+
+        ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000  # type: ignore
+        print(f"seq_kv: {num_blocks * block_size}, time: {ans_time * 1e6: 4.0f} us")
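Note: the new __main__ block checks dequantize_k_cache_paged against the unpaged path and then benchmarks it. For orientation only, here is a minimal pure-PyTorch sketch of the gather that a page_size=1 page table describes; gather_paged_rows is an illustrative name, not part of the module:

    import torch

    def gather_paged_rows(k_cache: torch.Tensor, page_table_1: torch.Tensor) -> torch.Tensor:
        # k_cache: (num_blocks, block_size, 1, dim); page_table_1 holds flat
        # token slots into the first num_blocks * block_size positions.
        num_blocks, block_size, h, dim = k_cache.shape
        flat = k_cache.reshape(num_blocks * block_size, h, dim)
        return flat[page_table_1.long()]

    # With page_table_1 = arange(num_blocks * block_size), as in the test above,
    # the gather is the identity, so the paged result must match the unpaged one.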
sglang/srt/layers/attention/nsa/transform_index.py

@@ -103,7 +103,7 @@ def transform_index_page_table_decode_ref(
     result = torch.empty_like(topk_indices, dtype=torch.int32)
     assert result.shape == topk_indices.shape
     torch.gather(
-        page_table,
+        page_table.to(result.dtype),
         dim=1,
         index=topk_indices.clamp(min=0),
         out=result,
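Note: torch.gather requires the out tensor to share the input's dtype, so the reference now casts the table to result's int32 before gathering, which makes the helper robust to page tables of any integer dtype. A standalone illustration (tensor values invented):

    import torch

    page_table = torch.arange(8, dtype=torch.int64).reshape(2, 4)  # e.g. an int64 table
    index = torch.zeros(2, 3, dtype=torch.int64)
    result = torch.empty(2, 3, dtype=torch.int32)

    # Gathering the int64 table straight into the int32 buffer raises a
    # RuntimeError; casting first, as the fix above does, makes dtypes agree.
    torch.gather(page_table.to(result.dtype), dim=1, index=index, out=result)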
sglang/srt/layers/attention/nsa_backend.py

@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from enum import IntEnum, auto
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
 
 import torch
 
 from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged
 from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.layers.attention.nsa.transform_index import (
@@ -98,11 +100,27 @@ class NSAMetadata:
     nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend
 
     flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
+    # The sum of sequence lengths for key, prefill only
+    seq_lens_sum: Optional[int] = None
+    # The flattened 1D page table with shape (seq_lens_sum,), prefill only
+    # this table is always with page_size = 1
+    page_table_1_flattened: Optional[torch.Tensor] = None
+    # The offset of topk indices in ragged kv, prefill only
+    # shape: (seq_lens_sum,)
+    topk_indices_offset: Optional[torch.Tensor] = None
+
+
+class TopkTransformMethod(IntEnum):
+    # Transform topk indices to indices to the page table (page_size = 1)
+    PAGED = auto()
+    # Transform topk indices to indices to ragged kv (non-paged)
+    RAGGED = auto()
 
 
 @dataclass(frozen=True)
 class NSAIndexerMetadata(BaseIndexerMetadata):
     attn_metadata: NSAMetadata
+    topk_transform_method: TopkTransformMethod
 
     def get_seqlens_int32(self) -> torch.Tensor:
         return self.attn_metadata.cache_seqlens_int32
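Note: the two enum members name two different index spaces for the transformed topk output. A toy illustration (numbers invented, not taken from the kernels):

    import torch

    # PAGED: entry j of request i is a position in that request's page-table
    # row, resolved as page_table_1[i, j] -> a global KV slot.
    page_table_1 = torch.tensor([[10, 11, 12, 13], [20, 21, 22, 23]])
    topk_local = torch.tensor([[0, 2], [1, 3]])
    paged = torch.gather(page_table_1, 1, topk_local)  # [[10, 12], [21, 23]]

    # RAGGED: indices stay relative to a back-to-back (flattened) KV layout, so
    # request i's entries are shifted by its segment start, cu_seqlens_k[i]
    # (see the topk_indices_offset construction further down).
    cu_seqlens_k = torch.tensor([0, 4, 8])
    ragged = topk_local + cu_seqlens_k[:-1, None]      # [[0, 2], [5, 7]]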
@@ -118,23 +136,36 @@ class NSAIndexerMetadata(BaseIndexerMetadata):
         logits: torch.Tensor,
         topk: int,
     ) -> torch.Tensor:
-        from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+        from sgl_kernel import (
+            fast_topk_transform_fused,
+            fast_topk_transform_ragged_fused,
+            fast_topk_v2,
+        )
 
         if not NSA_FUSE_TOPK:
             return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
-
-        # NOTE(dark): if fused, we return a transformed page table directly
-        return fast_topk_transform_fused(
-            score=logits,
-            lengths=self.get_seqlens_expanded(),
-            page_table_size_1=self.attn_metadata.page_table_1,
-            cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
-            topk=topk,
-        )
+        elif self.topk_transform_method == TopkTransformMethod.PAGED:
+            # NOTE(dark): if fused, we return a transformed page table directly
+            return fast_topk_transform_fused(
+                score=logits,
+                lengths=self.get_seqlens_expanded(),
+                page_table_size_1=self.attn_metadata.page_table_1,
+                cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
+                topk=topk,
+            )
+        elif self.topk_transform_method == TopkTransformMethod.RAGGED:
+            return fast_topk_transform_ragged_fused(
+                score=logits,
+                lengths=self.get_seqlens_expanded(),
+                topk_indices_offset=self.attn_metadata.topk_indices_offset,
+                topk=topk,
+            )
+        else:
+            assert False, f"Unsupported {self.topk_transform_method = }"
 
 
 def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
-    assert seqlens.dtype == torch.int32 
+    assert seqlens.dtype == torch.int32
     return torch.nn.functional.pad(
         torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
     )
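Note: compute_cu_seqlens is the usual exclusive-prefix-sum construction; a quick worked example with invented lengths:

    import torch

    seqlens = torch.tensor([3, 1, 4], dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    # tensor([0, 3, 4, 8], dtype=torch.int32): cu_seqlens[i] is where request
    # i's tokens start, and cu_seqlens[-1] is the total token count.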
@@ -181,6 +212,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
         NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
         NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
+        self.enable_auto_select_prefill_impl = NSA_PREFILL_IMPL == "flashmla_auto"
 
         self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
 
@@ -231,10 +263,16 @@ class NativeSparseAttnBackend(AttentionBackend):
         cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
         assert forward_batch.seq_lens_cpu is not None
         max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
+        # [b, max_seqlen_k]
         page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, :max_seqlen_k
         ]
 
+        page_table_1_flattened = None
+        topk_indices_offset = None
+        self.set_nsa_prefill_impl(forward_batch)
+        topk_transform_method = self.get_topk_transform_method()
+
         if forward_batch.forward_mode.is_decode_or_idle():
             extend_seq_lens_cpu = [1] * batch_size
             max_seqlen_q = 1
@@ -295,6 +333,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             else:
                 max_seqlen_q = max_seqlen_k
                 cu_seqlens_q = cu_seqlens_k
+
             seqlens_expanded = torch.cat(
                 [
                     torch.arange(
@@ -310,6 +349,24 @@ class NativeSparseAttnBackend(AttentionBackend):
                     )
                 ]
             )
+
+            if topk_transform_method == TopkTransformMethod.RAGGED:
+                page_table_1_flattened = torch.cat(
+                    [
+                        page_table[i, :kv_len]
+                        for i, kv_len in enumerate(
+                            forward_batch.seq_lens_cpu.tolist(),
+                        )
+                    ]
+                )
+                assert (
+                    page_table_1_flattened.shape[0] == forward_batch.seq_lens_sum
+                ), f"{page_table_1_flattened.shape[0] = } must be the same as {forward_batch.seq_lens_sum = }"
+
+                topk_indices_offset = torch.repeat_interleave(
+                    cu_seqlens_k[:-1],
+                    forward_batch.extend_seq_lens,
+                )
         else:
             assert False, f"Unsupported {forward_batch.forward_mode = }"
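Note: the topk_indices_offset construction repeats each request's KV segment start once per query (extend) token, so each row of the topk output has a matching offset. With toy values:

    import torch

    cu_seqlens_k = torch.tensor([0, 3, 8, 10])   # KV segment starts + total
    extend_seq_lens = torch.tensor([2, 1, 3])    # query tokens per request
    offsets = torch.repeat_interleave(cu_seqlens_k[:-1], extend_seq_lens)
    # tensor([0, 0, 3, 8, 8, 8]): one offset per query token.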
@@ -328,7 +385,9 @@ class NativeSparseAttnBackend(AttentionBackend):
             max_seq_len_k=max_seqlen_k,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
+            seq_lens_sum=forward_batch.seq_lens_sum,
             page_table_1=page_table,
+            page_table_1_flattened=page_table_1_flattened,
             flashmla_metadata=(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens_int32,
@@ -344,6 +403,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             nsa_extend_seq_lens_list=extend_seq_lens_cpu,
             real_page_table=self._transform_table_1_to_real(page_table),
             nsa_max_seqlen_q=1,
+            topk_indices_offset=topk_indices_offset,
         )
 
         self.forward_metadata = metadata
@@ -396,6 +456,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
     ):
+        self.set_nsa_prefill_impl(forward_batch=None)
+
         """Initialize forward metadata for capturing CUDA graph."""
         if forward_mode.is_decode_or_idle():
             # Normal Decode
@@ -586,6 +648,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         """Initialize forward metadata for replaying CUDA graph."""
         assert seq_lens_cpu is not None
 
+        self.set_nsa_prefill_impl(forward_batch=None)
+
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
@@ -780,17 +844,31 @@ class NativeSparseAttnBackend(AttentionBackend):
             q_rope = q_all[:, :, layer.v_head_dim :]
 
         # NOTE(dark): here, we use page size = 1
-
+        topk_transform_method = self.get_topk_transform_method()
         if NSA_FUSE_TOPK:
             page_table_1 = topk_indices
         else:
-            assert metadata.nsa_extend_seq_lens_list is not None
-            page_table_1 = transform_index_page_table_prefill(
-                page_table=metadata.page_table_1,
-                topk_indices=topk_indices,
-                extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
-                page_size=1,
-            )
+            if topk_transform_method == TopkTransformMethod.RAGGED:
+                topk_indices_offset = metadata.topk_indices_offset
+                assert topk_indices_offset is not None
+                mask = topk_indices != -1
+                topk_indices_offset = (
+                    topk_indices_offset.unsqueeze(1)
+                    if topk_indices_offset.ndim == 1
+                    else topk_indices_offset
+                )
+                topk_indices = torch.where(
+                    mask, topk_indices + topk_indices_offset, topk_indices
+                )
+            elif topk_transform_method == TopkTransformMethod.PAGED:
+                assert metadata.nsa_extend_seq_lens_list is not None
+                page_table_1 = transform_index_page_table_prefill(
+                    page_table=metadata.page_table_1,
+                    topk_indices=topk_indices,
+                    extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+                    page_size=1,
+                )
+
         if NSA_PREFILL_IMPL == "tilelang":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
@@ -804,6 +882,22 @@ class NativeSparseAttnBackend(AttentionBackend):
         elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+
+            # NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here,
+            # because the flashmla_sparse kernel doesn't support fp8 compute
+            if topk_transform_method == TopkTransformMethod.RAGGED:
+                if any(forward_batch.extend_prefix_lens_cpu):
+                    page_table_1_flattened = (
+                        self.forward_metadata.page_table_1_flattened
+                    )
+                    assert page_table_1_flattened is not None
+                    kv_cache = dequantize_k_cache_paged(
+                        kv_cache, page_table_1_flattened
+                    )
+                else:
+                    kv_cache = torch.cat([k, k_rope], dim=-1)
+                page_table_1 = topk_indices
+
             return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
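Note: in the unfused RAGGED path above, -1 entries mark padded or invalid topk slots and must not be shifted into another request's segment. A standalone rendering of that masked add, with invented values:

    import torch

    topk_indices = torch.tensor([[0, 2, -1], [1, -1, -1]])
    offsets = torch.tensor([0, 3]).unsqueeze(1)  # per-row KV segment starts
    mask = topk_indices != -1
    shifted = torch.where(mask, topk_indices + offsets, topk_indices)
    # tensor([[ 0,  2, -1],
    #         [ 4, -1, -1]])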
@@ -1004,7 +1098,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         page_table_1: torch.Tensor,
         sm_scale: float,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_sparse_fwd
+        from sgl_kernel.flash_mla import flash_mla_sparse_fwd
 
         o, _, _ = flash_mla_sparse_fwd(
             q=q_all,
@@ -1025,7 +1119,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         metadata: NSAMetadata,
         page_table_1,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_with_kvcache
+        from sgl_kernel.flash_mla import flash_mla_with_kvcache
 
         cache_seqlens = metadata.nsa_cache_seqlens_int32
 
@@ -1121,13 +1215,53 @@ class NativeSparseAttnBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 1
 
+    def set_nsa_prefill_impl(self, forward_batch: Optional[ForwardBatch] = None) -> str:
+        from sglang.srt.utils import is_blackwell
+
+        global NSA_PREFILL_IMPL
+        if self.enable_auto_select_prefill_impl:
+            if self.nsa_kv_cache_store_fp8:
+                if (
+                    is_blackwell()
+                    and forward_batch is not None
+                    and forward_batch.forward_mode == ForwardMode.EXTEND
+                ):
+                    total_kv_tokens = forward_batch.seq_lens_sum
+                    total_q_tokens = forward_batch.extend_num_tokens
+                    # Heuristic based on benchmarking flashmla_kv vs flashmla_sparse + dequantize_k_cache_paged
+                    if total_kv_tokens < total_q_tokens * 512:
+                        NSA_PREFILL_IMPL = "flashmla_sparse"
+                        return
+                NSA_PREFILL_IMPL = "flashmla_kv"
+            else:
+                # bf16 kv cache
+                NSA_PREFILL_IMPL = "flashmla_sparse"
+
+    def get_topk_transform_method(self) -> TopkTransformMethod:
+        """
+        NSA_FUSE_TOPK controls whether to fuse the topk transform into the topk kernel.
+        This method is used to select the topk transform method which can be fused or unfused.
+        """
+        if (
+            # disable for MTP
+            self.nsa_kv_cache_store_fp8
+            and NSA_PREFILL_IMPL == "flashmla_sparse"
+        ):
+            topk_transform_method = TopkTransformMethod.RAGGED
+        else:
+            topk_transform_method = TopkTransformMethod.PAGED
+        return topk_transform_method
+
     def get_indexer_metadata(
         self, layer_id: int, forward_batch: ForwardBatch
     ) -> NSAIndexerMetadata:
-        return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
+        return NSAIndexerMetadata(
+            attn_metadata=self.forward_metadata,
+            topk_transform_method=self.get_topk_transform_method(),
+        )
 
     def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
-        from flash_mla import get_mla_metadata
+        from sgl_kernel.flash_mla import get_mla_metadata
 
         flashmla_metadata, num_splits = get_mla_metadata(
             cache_seqlens=cache_seqlens,
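Note: the auto-selection rule above reduces to a single comparison; restated standalone (the 512 threshold is the empirical constant from the benchmark comment in the diff, not a fundamental one, and pick_prefill_impl is an illustrative name):

    def pick_prefill_impl(total_kv_tokens: int, total_q_tokens: int) -> str:
        # Short-prefix extends: dequantize the fp8 cache once and run the
        # sparse kernel; long-context, few-token extends favor flashmla_kv.
        if total_kv_tokens < total_q_tokens * 512:
            return "flashmla_sparse"
        return "flashmla_kv"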
sglang/srt/layers/attention/triton_backend.py
CHANGED

@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        if
+        if (
+            model_runner.hybrid_gdn_config is not None
+            or model_runner.kimi_linear_config is not None
+        ):
             # For hybrid linear models, layer_id = 0 may not be full attention
             self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
         else:
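The guard now treats a model as hybrid-linear when either `hybrid_gdn_config` or `kimi_linear_config` is set (Kimi Linear support is new in this release, per `sglang/srt/configs/kimi_linear.py`). A minimal sketch of the predicate, with a stand-in object carrying the two attributes the diff checks:

# Sketch only: the new guard, factored into a helper for illustration.
def is_hybrid_linear(model_runner) -> bool:
    return (
        model_runner.hybrid_gdn_config is not None
        or model_runner.kimi_linear_config is not None
    )

class _Stub:  # hypothetical stand-in for a ModelRunner
    hybrid_gdn_config = None
    kimi_linear_config = object()  # e.g. a loaded KimiLinearConfig

assert is_hybrid_linear(_Stub())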
sglang/srt/layers/attention/trtllm_mha_backend.py
CHANGED

@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-            if (
-
-
-            ):
+            if any(
+                forward_batch.extend_prefix_lens_cpu
+            ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
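The rewritten condition leans on `any()` over a list of Python ints: `forward_batch.extend_prefix_lens_cpu` holds per-request prefix lengths, so the branch fires as soon as any request extends a non-empty cached prefix, or when the batch is a draft-extend pass. A worked check of that truthiness (the list values are made up):

# any() over ints is true iff at least one entry is nonzero, i.e. at least
# one request in the batch has a cached prefix to extend.
extend_prefix_lens_cpu = [0, 0, 7, 0]  # hypothetical per-request prefix lengths
assert any(extend_prefix_lens_cpu)
assert not any([0, 0, 0])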
@@ -529,6 +528,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )

+        if self.data_type == torch.float8_e4m3fn:
+            q = q.to(torch.float8_e4m3fn)
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
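Both the extend and decode paths (this hunk and the one at line 588 below) now cast the query to `torch.float8_e4m3fn` when the backend's KV data type is FP8, so queries enter the TRT-LLM kernel in the same dtype as the cache. A standalone sketch of that cast; the shapes and the `kv_dtype` variable are made up for illustration:

import torch

kv_dtype = torch.float8_e4m3fn  # stands in for self.data_type in the diff
q = torch.randn(4, 32, 128, dtype=torch.bfloat16)  # [tokens, q_heads, head_dim]
if kv_dtype == torch.float8_e4m3fn:
    q = q.to(torch.float8_e4m3fn)  # quantize queries to match the FP8 KV cache
assert q.dtype == torch.float8_e4m3fn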
@@ -567,6 +568,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
             sinks=attention_sink,
+            out_dtype=self.q_data_type,  # model_runner.dtype
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -586,6 +588,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
+
+        if self.data_type == torch.float8_e4m3fn:
+            q = q.to(torch.float8_e4m3fn)
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)

@@ -625,6 +630,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
             sinks=attention_sink,
+            out_dtype=self.q_data_type,  # model_runner.dtype
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
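Passing `out_dtype=self.q_data_type` asks the kernel to return the attention output in the model's compute dtype even when queries and KV are FP8; the trailing `# model_runner.dtype` comment records where that dtype comes from. The unchanged return line then folds heads back into the hidden dimension. A small shape check of that view (shapes are made up):

import torch

# [tokens, heads, head_dim] -> [tokens, heads * head_dim], as in the return line.
tp_q_head_num, head_dim = 32, 128
o = torch.zeros(4, tp_q_head_num, head_dim, dtype=torch.bfloat16)
out = o.view(-1, tp_q_head_num * head_dim)
assert out.shape == (4, 4096)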
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED

@@ -944,8 +944,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                     metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
                 )
             else:
-
-
+                # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
+                # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
+                # This will ensure queries align with kvs correctly when calling
+                # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
+                seq_lens = (
+                    forward_batch.seq_lens
+                    - metadata.seq_lens_q
+                    + metadata.max_seq_len_q
+                ).to(torch.int32)
+                max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
                 # Check if we're in CUDA graph mode (buffers are pre-allocated)
                 if self.padded_q_buffer is not None:
                     # Use pre-allocated buffer for CUDA graph compatibility
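The arithmetic here replaces each request's own verified-token count with the padded batch maximum, so the KV lengths line up with padded draft-extend queries. A worked example using local names that mirror the diff's `forward_batch.seq_lens`, `metadata.seq_lens_q`, and `metadata.max_seq_len_q`:

import torch

# Two requests: contexts of 10 and 20 tokens, with 3 and 5 verified draft
# tokens; queries are padded to the batch max of 5.
seq_lens_total = torch.tensor([13, 25])  # prev_context + verified tokens
seq_lens_q = torch.tensor([3, 5])        # per-request query lengths
max_seq_len_q = 5                        # padded query length
seq_lens = (seq_lens_total - seq_lens_q + max_seq_len_q).to(torch.int32)
print(seq_lens)  # tensor([15, 25], dtype=torch.int32) = prev_context + max draft tokens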
sglang/srt/layers/communicator.py
CHANGED

@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import torch

@@ -216,6 +216,28 @@ class LayerCommunicator:
             get_global_server_args().speculative_algorithm
         )

+    def prepare_attn_and_capture_last_layer_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
+    ):
+        hidden_states, residual = self.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+        if captured_last_layer_outputs is not None:
+            gathered_last_layer_output = self._communicate_simple_fn(
+                hidden_states=residual,
+                forward_batch=forward_batch,
+                context=self._context,
+            )
+            if gathered_last_layer_output is residual:
+                # Clone to avoid modifying the original residual by Custom RMSNorm inplace operation
+                gathered_last_layer_output = residual.clone()
+            captured_last_layer_outputs.append(gathered_last_layer_output)
+        return hidden_states, residual
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
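The `residual.clone()` in the new method matters because fused RMSNorm-style ops mutate `residual` in place; capturing a bare alias would let later layers silently rewrite the captured value. An illustrative demonstration of the aliasing hazard, independent of sglang:

import torch

residual = torch.ones(4)
captured = residual        # alias, not a copy
residual.add_(1.0)         # in-place update, as a fused add+norm would do
print(captured)            # tensor([2., 2., 2., 2.]) -- the alias changed too

snapshot = residual.clone()
residual.add_(1.0)
print(snapshot)            # tensor([2., 2., 2., 2.]) -- the copy stays stable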
    
sglang/srt/layers/layernorm.py
CHANGED

@@ -20,7 +20,12 @@ import torch
 import torch.nn as nn
 from packaging.version import Version

+from sglang.srt.batch_invariant_ops import (
+    is_batch_invariant_mode_enabled,
+    rms_norm_batch_invariant,
+)
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,

@@ -90,8 +95,6 @@ class RMSNorm(CustomOp):
         )
         if _use_aiter:
             self._forward_method = self.forward_aiter
-        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
-            self._forward_method = self.forward_native

     def forward_cuda(
         self,

@@ -100,6 +103,17 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
+        if is_batch_invariant_mode_enabled():
+            if (
+                residual is not None
+                or get_global_server_args().rl_on_policy_target == "fsdp"
+            ):
+                return self.forward_native(x, residual)
+            return rms_norm_batch_invariant(
+                x,
+                self.weight.data,
+                self.variance_epsilon,
+            )
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
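Deterministic inference is no longer toggled per-layer via the `SGLANG_ENABLE_DETERMINISTIC_INFERENCE` env var; `forward_cuda` now asks `is_batch_invariant_mode_enabled()` and routes residual-free calls to `rms_norm_batch_invariant` from the new `batch_invariant_ops` package, falling back to `forward_native` otherwise. For reference, a plain-PyTorch RMSNorm equivalent to what the native path computes; this is a sketch, and the actual kernels differ in fusion and precision details:

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dim, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    y = x.float() * torch.rsqrt(variance + eps)
    return (y * weight.float()).to(x.dtype)

x = torch.randn(2, 8)
w = torch.ones(8)
out = rms_norm_ref(x, w, eps=1e-6)
assert out.shape == x.shape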
sglang/srt/layers/logits_processor.py
CHANGED

@@ -38,7 +38,6 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (

@@ -47,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import
+from sglang.srt.utils import is_npu, use_intel_amx_backend

 logger = logging.getLogger(__name__)

@@ -135,10 +134,7 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            (
-                forward_batch.forward_mode.is_extend()
-                or forward_batch.forward_mode.is_split_prefill()
-            )
+            forward_batch.forward_mode.is_extend()
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):

@@ -252,10 +248,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None

-        self.debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
-        )
-
     def compute_logprobs_for_multi_item_scoring(
         self,
         input_ids,

@@ -389,8 +381,8 @@ class LogitsProcessor(nn.Module):
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-
-        )
+            and not logits_metadata.extend_return_logprob
+        ):
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1

@@ -463,14 +455,6 @@ class LogitsProcessor(nn.Module):
             logits[sample_indices] if sample_indices is not None else logits
         )

-        if self.debug_tensor_dump_output_folder:
-            assert (
-                not self.do_tensor_parallel_all_gather
-                or get_local_attention_dp_size() == 1
-            ), "dp attention + sharded lm_head doesn't support full logits"
-            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
-            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
-
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
            if logits_metadata.capture_hidden_mode.is_full():
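In the prefill-without-input-logprobs branch, only the last token of each request needs logits, and the `torch.cumsum(...) - 1` line above turns per-request extend lengths into flat indices of those last tokens in the packed batch. A worked example:

import torch

# Requests extend by 3, 2, and 4 tokens; their last tokens sit at flat
# positions 2, 4, and 8 in the packed hidden-state tensor.
extend_seq_lens = torch.tensor([3, 2, 4])
last_index = torch.cumsum(extend_seq_lens, dim=0) - 1
print(last_index)  # tensor([2, 4, 8])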
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED

@@ -131,23 +131,6 @@ class DeepEPMoE(FusedMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        elif not _is_npu:
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                (
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w13_weight_scale
-                ),
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                (
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w2_weight_scale
-                ),
-            )

     def forward(
         self,

@@ -235,7 +218,6 @@ class DeepEPMoE(FusedMoE):
             hidden_states=output,
             topk_ids=dispatch_output.topk_ids,
             topk_weights=dispatch_output.topk_weights,
-            overlap_args=down_gemm_overlap_args,
         )

     def combine(