sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/models/deepseek_v2.py  CHANGED

@@ -21,7 +21,7 @@ import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -131,13 +131,11 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cpu,
     is_cuda,
-    is_flashinfer_available,
     is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
     is_nvidia_cublas_cu12_version_ge_12_9,
-    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -197,8 +195,6 @@ elif _is_npu:
 else:
     pass
 
-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
 _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
@@ -228,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
         logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
 
 
+def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
+    """
+    NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
+    """
+    return (
+        is_deepseek_nsa(config)
+        and quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+    )
+
+
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
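The new is_nsa_indexer_wk_and_weights_proj_fused helper gates the wk/weights_proj fusion on the checkpoint's quantization method: only DeepSeek NSA models quantized with ModelOpt FP4 keep both projections in BF16, so only there can they share a single GEMM. A minimal, self-contained sketch of the same predicate; StubQuantConfig below is an illustrative stand-in, not an sglang class.

from dataclasses import dataclass
from typing import Optional


@dataclass
class StubQuantConfig:
    # Illustrative stand-in for a quantization config exposing get_name().
    name: str

    def get_name(self) -> str:
        return self.name


def wk_and_weights_proj_fused(is_nsa_model: bool, quant_config: Optional[StubQuantConfig]) -> bool:
    # Mirrors the gating logic: NSA model + ModelOpt FP4 quantization.
    return (
        is_nsa_model
        and quant_config is not None
        and quant_config.get_name() == "modelopt_fp4"
    )


assert wk_and_weights_proj_fused(True, StubQuantConfig("modelopt_fp4"))
assert not wk_and_weights_proj_fused(True, StubQuantConfig("fp8"))
assert not wk_and_weights_proj_fused(True, None)   # unquantized checkpoint
assert not wk_and_weights_proj_fused(False, StubQuantConfig("modelopt_fp4"))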
@@ -283,6 +290,7 @@ def handle_attention_ascend(attn, forward_batch):
         forward_batch.forward_mode.is_extend()
         and not forward_batch.forward_mode.is_target_verify()
         and not forward_batch.forward_mode.is_draft_extend()
+        and not forward_batch.forward_mode.is_draft_extend_v2()
     ):
         if hasattr(attn, "indexer"):
             return AttnForwardMethod.NPU_MLA_SPARSE
@@ -519,6 +527,9 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )
 
+        if get_global_server_args().enable_deterministic_inference:
+            return F.linear(hidden_states, self.weight, None)
+
         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
@@ -1064,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         layer_id: int = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_rope: bool = False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -1144,6 +1156,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 alt_stream=alt_stream,
+                fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
+                    config, quant_config
+                ),
             )
 
         self.kv_b_proj = ColumnParallelLinear(
@@ -1168,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=get_global_server_args().device,
-        )
+        if not skip_rope:
+            self.rotary_emb = get_rope_wrapper(
+                qk_rope_head_dim,
+                rotary_dim=qk_rope_head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+                is_neox_style=False,
+                device=get_global_server_args().device,
+            )
 
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
+            if rope_scaling:
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+                scaling_factor = rope_scaling["factor"]
+                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+                self.scaling = self.scaling * mscale * mscale
+            else:
+                self.rotary_emb.forward = self.rotary_emb.forward_native
         else:
-            self.rotary_emb.forward = self.rotary_emb.forward_native
+            self.rotary_emb = None
 
         self.attn_mqa = RadixAttention(
             self.num_local_heads,
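In the rope branch above, the attention softmax scale is multiplied by mscale twice when YaRN-style rope_scaling is configured. A worked numeric sketch of that adjustment follows; the yarn_get_mscale body is the formula commonly used in DeepSeek-style YaRN implementations (0.1 · mscale · ln(factor) + 1), and the factor, mscale_all_dim, and head-dim values are illustrative assumptions, not values read from this diff.

import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Magnitude correction for YaRN-extended RoPE; no correction when scale <= 1.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# Illustrative DeepSeek-V3-like values: 40x context extension, mscale_all_dim = 1.0,
# qk head dim 192 (128 nope + 64 rope).
qk_head_dim = 192
base_scaling = qk_head_dim ** -0.5
mscale = yarn_get_mscale(scale=40, mscale=1.0)

scaling = base_scaling * mscale * mscale
print(f"mscale={mscale:.4f}, scaling={scaling:.6f}")   # mscale ~ 1.369, ~1.87x the base scale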
@@ -1260,7 +1278,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
             and _is_cuda
-            and _device_sm >= 90
+            and 90 <= _device_sm < 120
         )
 
         self.qkv_proj_with_rope_is_int8 = (
@@ -1473,7 +1491,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        if self.rotary_emb is not None:
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
 
         self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
@@ -1632,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported
+        if (
+            self.rotary_emb is not None
+            and (not self._fuse_rope_for_trtllm_mla(forward_batch))
+            and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
         ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
@@ -2828,6 +2849,7 @@ class DeepseekV2Model(nn.Module):
                     self.embed_tokens.embedding_dim,
                 )
             )
+        self.layers_to_capture = []
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2884,9 +2906,11 @@ class DeepseekV2Model(nn.Module):
                 normal_end_layer = self.first_k_dense_replace
             elif self.first_k_dense_replace < normal_start_layer:
                 normal_end_layer = normal_start_layer = 0
-
+        aux_hidden_states = []
         for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions,
@@ -2924,7 +2948,9 @@ class DeepseekV2Model(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+        return hidden_states, aux_hidden_states
 
 
 class DeepseekV2ForCausalLM(nn.Module):
@@ -2978,6 +3004,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if isinstance(layer.mlp, DeepseekV2MoE)
             }
         )
+        self.capture_aux_hidden_states = False
 
     @property
     def routed_experts_weights_of_layer(self):
@@ -3002,7 +3029,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
-        elif self.quant_config.get_name() == "w4afp8":
+        elif self.quant_config and self.quant_config.get_name() == "w4afp8":
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
@@ -3031,10 +3058,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
 
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
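The layers_to_capture / aux_hidden_states plumbing above snapshots the residual stream just before selected decoder layers and returns those snapshots next to the final hidden states, so the logits processor (and ultimately an EAGLE3 draft model) can consume them. A toy version of the same capture pattern, with made-up layer count, sizes, and a simplified residual update rather than the real transformer block.

import torch

torch.manual_seed(0)
num_layers = 12
layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(num_layers)])
layers_to_capture = [2, num_layers // 2, num_layers - 3]   # mirrors the default selection

hidden = torch.randn(1, 8)
residual = torch.zeros_like(hidden)
aux_hidden_states = []

for i, layer in enumerate(layers):
    if i in layers_to_capture:
        # Snapshot the residual stream *before* running layer i.
        aux_hidden_states.append(hidden + residual)
    hidden, residual = layer(hidden + residual), hidden

# With capture enabled, forward() returns both and the caller unpacks the tuple.
out = (hidden, aux_hidden_states) if aux_hidden_states else hidden
print(len(aux_hidden_states))   # 3 snapshots: layers 2, 6, and 9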
@@ -3293,8 +3323,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 experts = layer.mlp.experts
                 if isinstance(experts, DeepEPMoE):
                     for w in [
-                        experts.w13_weight_fp8,
-                        experts.w2_weight_fp8,
+                        (experts.w13_weight, experts.w13_weight_scale_inv),
+                        (experts.w2_weight, experts.w2_weight_scale_inv),
                     ]:
                         requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
             else:
@@ -3342,10 +3372,26 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
 
         experts = layer.mlp.experts
+        w13_weight_fp8 = (
+            experts.w13_weight,
+            (
+                experts.w13_weight_scale_inv
+                if hasattr(experts, "w13_weight_scale_inv")
+                else experts.w13_weight_scale
+            ),
+        )
+        w2_weight_fp8 = (
+            experts.w2_weight,
+            (
+                experts.w2_weight_scale_inv
+                if hasattr(experts, "w2_weight_scale_inv")
+                else experts.w2_weight_scale
+            ),
+        )
         if isinstance(experts, DeepEPMoE):
             for w in [
-                experts.w13_weight_fp8,
-                experts.w2_weight_fp8,
+                w13_weight_fp8,
+                w2_weight_fp8,
             ]:
                 transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
@@ -3398,6 +3444,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             self.config.q_lora_rank is not None
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
+        fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
+            self.config, self.quant_config
+        )
+        cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
 
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3569,6 +3619,53 @@ class DeepseekV2ForCausalLM(nn.Module):
                             )
                             cached_a_proj.pop(q_a_proj_name)
                             cached_a_proj.pop(kv_a_proj_name)
+                    elif fuse_wk_and_weights_proj and (
+                        "wk" in name or "weights_proj" in name
+                    ):
+                        cached_wk_and_weights_proj[name] = loaded_weight
+                        wk_name = (
+                            name
+                            if "wk" in name
+                            else name.replace("weights_proj", "wk")
+                        )
+                        weights_proj_name = (
+                            name
+                            if "weights_proj" in name
+                            else name.replace("wk", "weights_proj")
+                        )
+
+                        # When both wk and weights_proj has been cached, load the fused weight to parameter
+                        if (
+                            wk_name in cached_wk_and_weights_proj
+                            and weights_proj_name in cached_wk_and_weights_proj
+                        ):
+                            wk_weight = cached_wk_and_weights_proj[wk_name]
+                            weights_proj_weight = cached_wk_and_weights_proj[
+                                weights_proj_name
+                            ]
+                            # todo dequantize wk for fp8
+                            assert wk_weight.dtype == weights_proj_weight.dtype
+                            fused_weight = torch.cat(
+                                [wk_weight, weights_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("wk", "fused_wk_and_weights_proj")
+                                if "wk" in name
+                                else name.replace(
+                                    "weights_proj",
+                                    "fused_wk_and_weights_proj",
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            futures.append(
+                                executor.submit(weight_loader, param, fused_weight)
+                            )
+                            cached_wk_and_weights_proj.pop(wk_name)
+                            cached_wk_and_weights_proj.pop(weights_proj_name)
                     else:
                         if (
                             "k_scale" in name or "v_scale" in name
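Checkpoints still store wk and weights_proj as two separate tensors, so the loader change above caches whichever half arrives first and only writes the fused parameter once both halves are present. The same two-way caching pattern in miniature; the parameter names and shapes here are illustrative only, not the real DeepSeek checkpoint keys.

import torch

# Checkpoint tensors arrive one at a time, in arbitrary order.
checkpoint = {
    "indexer.wk.weight": torch.randn(128, 7168),
    "indexer.weights_proj.weight": torch.randn(64, 7168),
}

cached = {}
fused_params = {}

for name, weight in checkpoint.items():
    cached[name] = weight
    # Derive both partner names from whichever half we just saw.
    wk_name = name if "wk" in name else name.replace("weights_proj", "wk")
    weights_proj_name = (
        name if "weights_proj" in name else name.replace("wk", "weights_proj")
    )
    if wk_name in cached and weights_proj_name in cached:
        # Both halves are present: concatenate on the output dim and load once.
        fused = torch.cat([cached.pop(wk_name), cached.pop(weights_proj_name)], dim=0)
        fused_name = wk_name.replace("wk", "fused_wk_and_weights_proj")
        fused_params[fused_name] = fused

assert fused_params["indexer.fused_wk_and_weights_proj.weight"].shape == (192, 7168)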
@@ -3664,8 +3761,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         del self.lm_head.weight
         self.model.embed_tokens.weight = embed
         self.lm_head.weight = head
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        if not _is_npu:
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        else:
+            torch.npu.empty_cache()
+            torch.npu.synchronize()
 
     @classmethod
     def get_model_config_for_expert_location(cls, config):
@@ -3675,6 +3776,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_groups=config.n_group,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
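set_eagle3_layers_to_capture is the hook the EAGLE3 speculative-decoding path calls on the target model to choose which residual-stream snapshots to expose. A short sketch of the layer selection it performs; num_hidden_layers = 61 is just the DeepSeek-V3 value, used here for illustration.

# Default selection when layer_ids is None: one early, one middle, one late layer.
num_hidden_layers = 61
default_capture = [2, num_hidden_layers // 2, num_hidden_layers - 3]
print(default_capture)             # [2, 30, 58]

# With explicit EAGLE3 layer ids, each id is shifted by +1 because layer i
# consumes the output of layer i - 1 as its aux hidden state.
layer_ids = [1, 29, 57]
print([i + 1 for i in layer_ids])  # [2, 30, 58]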