sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/model_executor/model_runner.py CHANGED

```diff
@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
-from sglang.srt.configs import
+from sglang.srt.configs import (
+    FalconH1Config,
+    KimiLinearConfig,
+    NemotronHConfig,
+    Qwen3NextConfig,
+)
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
```
```diff
@@ -40,6 +45,9 @@ from sglang.srt.configs.model_config import (
 )
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.debug_utils.tensor_dump_forward_hook import (
+    register_forward_hook_for_model,
+)
 from sglang.srt.distributed import (
     get_pp_group,
     get_tp_group,
```
```diff
@@ -77,7 +85,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
```
```diff
@@ -349,7 +356,11 @@ class ModelRunner:
 
         if not self.is_draft_worker:
             set_global_expert_location_metadata(
-                compute_initial_expert_location_metadata(
+                compute_initial_expert_location_metadata(
+                    server_args=server_args,
+                    model_config=self.model_config,
+                    moe_ep_rank=self.moe_ep_rank,
+                )
             )
             if self.tp_rank == 0 and get_bool_env_var(
                 "SGLANG_LOG_EXPERT_LOCATION_METADATA"
```
```diff
@@ -730,7 +741,6 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
-        monkey_patch_isinstance_for_vllm_base_layer()
 
         with self.memory_saver_adapter.region(
             GPU_MEMORY_TYPE_WEIGHTS,
```
```diff
@@ -742,7 +752,6 @@ class ModelRunner:
                 device_config=DeviceConfig(self.device, self.gpu_id),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
-        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
         get_offloader().post_init()
 
```
```diff
@@ -790,6 +799,15 @@ class ModelRunner:
             f"avail mem={after_avail_memory:.2f} GB, "
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
+        if self.server_args.debug_tensor_dump_output_folder is not None:
+            register_forward_hook_for_model(
+                self.model,
+                self.server_args.debug_tensor_dump_output_folder,
+                self.server_args.debug_tensor_dump_layers,
+                self.tp_size,
+                self.tp_rank,
+                self.pp_rank,
+            )
 
         if self.server_args.elastic_ep_backend == "mooncake":
             # Mooncake does not support `monitored_barrier`
```
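This hunk wires the new `sglang/srt/debug_utils/tensor_dump_forward_hook.py` helper into model load: when `server_args.debug_tensor_dump_output_folder` is set, each rank registers forward hooks that dump intermediate tensors for offline comparison. The sketch below shows the general mechanism using plain PyTorch forward hooks; the file naming and layer filtering here are illustrative assumptions, not the helper's actual behavior.

```python
# Minimal sketch of dumping per-module outputs via PyTorch forward hooks.
# The dump path layout is an assumption for illustration only.
import os

import torch
from torch import nn


def dump_outputs_via_hooks(model: nn.Module, out_dir: str, tp_rank: int) -> None:
    """Save each named module's output tensor to disk after every forward."""
    os.makedirs(out_dir, exist_ok=True)
    step = {"n": 0}  # shared counter so repeated forwards get distinct files

    def make_hook(name: str):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                fname = f"{name}.step{step['n']}.tp{tp_rank}.pt"
                torch.save(output.detach().cpu(), os.path.join(out_dir, fname))
            step["n"] += 1
        return hook

    for name, module in model.named_modules():
        if name:  # skip the unnamed root module
            module.register_forward_hook(make_hook(name))
```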
```diff
@@ -1345,9 +1363,16 @@ class ModelRunner:
             return config
         return None
 
+    @property
+    def kimi_linear_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, KimiLinearConfig):
+            return config
+        return None
+
     @property
     def mambaish_config(self):
-        return self.mamba2_config or self.hybrid_gdn_config
+        return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
 
     def set_num_token_hybrid(self):
         if (
```
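`mambaish_config` decides whether the runner takes the hybrid linear-attention KV-pool path, and the `or`-chain returns the first non-None probe, so Kimi Linear models now join the Mamba2 and hybrid-GDN families there. A toy reduction of the pattern (the classes here are stand-ins, not sglang's real configs):

```python
# Toy reduction of the first-non-None config chain (stand-in classes).
class KimiLinearConfig:
    pass


class Runner:
    def __init__(self, hf_config):
        self.hf_config = hf_config

    @property
    def kimi_linear_config(self):
        # Probe: return the config only when it has the matching type.
        cfg = self.hf_config
        return cfg if isinstance(cfg, KimiLinearConfig) else None

    @property
    def mambaish_config(self):
        # `or` yields the first non-None probe in the chain.
        mamba2_config = hybrid_gdn_config = None  # other probes elided
        return mamba2_config or hybrid_gdn_config or self.kimi_linear_config


assert Runner(KimiLinearConfig()).mambaish_config is not None
assert Runner(object()).mambaish_config is None
```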
```diff
@@ -1658,9 +1683,11 @@ class ModelRunner:
                     get_attention_tp_size()
                 ),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.use_mla_backend and is_nsa_model:
             self.token_to_kv_pool = NSATokenToKVPool(
```
```diff
@@ -1676,7 +1703,7 @@ class ModelRunner:
                 end_layer=self.end_layer,
                 index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
             )
-        elif self.use_mla_backend:
+        elif self.use_mla_backend and not self.mambaish_config:
             assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
```
```diff
@@ -1720,6 +1747,12 @@ class ModelRunner:
                 device=self.device,
             )
         elif config := self.mambaish_config:
+            extra_args = {}
+            if self.use_mla_backend:
+                extra_args = {
+                    "kv_lora_rank": self.model_config.kv_lora_rank,
+                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
+                }
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
```
```diff
@@ -1735,6 +1768,8 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
                 mamba_pool=self.req_to_token_pool.mamba_pool,
+                use_mla=self.use_mla_backend,
+                **extra_args,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
```
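Together with the previous hunk, the MLA geometry (`kv_lora_rank`, `qk_rope_head_dim`) is forwarded to `HybridLinearKVPool` only when the MLA backend is active, via a conditionally built kwargs dict splatted into the constructor. A minimal sketch of the pattern (the pool below is a stand-in dataclass, not sglang's class):

```python
# Minimal sketch of the conditional-kwargs splat (Pool is a stand-in).
from dataclasses import dataclass
from typing import Optional


@dataclass
class Pool:
    use_mla: bool
    kv_lora_rank: Optional[int] = None
    qk_rope_head_dim: Optional[int] = None


def build_pool(use_mla: bool, model_config: dict) -> Pool:
    extra_args = {}
    if use_mla:
        # Only MLA models carry these two geometry fields.
        extra_args = {
            "kv_lora_rank": model_config["kv_lora_rank"],
            "qk_rope_head_dim": model_config["qk_rope_head_dim"],
        }
    return Pool(use_mla=use_mla, **extra_args)


assert build_pool(False, {}).kv_lora_rank is None
assert build_pool(True, {"kv_lora_rank": 512, "qk_rope_head_dim": 64}).kv_lora_rank == 512
```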
```diff
@@ -1750,6 +1785,7 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_alt_stream=not self.server_args.enable_pdmux,
                 enable_kv_cache_copy=(
                     self.server_args.speculative_algorithm is not None
                 ),
```
```diff
@@ -1818,12 +1854,18 @@ class ModelRunner:
 
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
 
-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()
```
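With `enable_pdmux`, the runner keeps one attention backend for prefill (built with a fresh workspace) plus a pool of decode backends, one per SM group, and switches among them per stream (see `update_decode_attn_backend` in a later hunk). A minimal sketch of the pool-and-switch shape, assuming simple index-based selection:

```python
# Minimal sketch of the pdmux decode-backend pool; names follow the hunk,
# but the selection policy shown is an illustrative assumption.
class Runner:
    def __init__(self, sm_group_num: int, make_backend):
        # One decode backend per SM group, created up front.
        self.decode_attn_backend_group = [make_backend() for _ in range(sm_group_num)]
        self.decode_attn_backend = self.decode_attn_backend_group[0]

    def update_decode_attn_backend(self, stream_idx: int) -> None:
        # The scheduler picks which group's backend serves the next decode.
        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]


r = Runner(sm_group_num=2, make_backend=object)
r.update_decode_attn_backend(1)
assert r.decode_attn_backend is r.decode_attn_backend_group[1]
```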
```diff
@@ -1837,10 +1879,12 @@ class ModelRunner:
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(
```
```diff
@@ -1854,7 +1898,8 @@ class ModelRunner:
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )
 
         (
```
```diff
@@ -1863,9 +1908,12 @@ class ModelRunner:
         ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend
 
-    def _get_attention_backend_from_str(
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)
```
```diff
@@ -1963,6 +2011,9 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,
```
```diff
@@ -1970,7 +2021,11 @@ class ModelRunner:
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
```
```diff
@@ -2108,18 +2163,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
             )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
```
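This hunk only reorders the dispatch: the extend branch now comes after the split-prefill branch. That ordering matters if a split-prefill batch also answers `is_extend()` as true, since checking extend first would shadow `forward_split_prefill`. Whether sglang's `ForwardMode` overlaps this way is an assumption here, but the sketch shows why a more specific mode must be tested first:

```python
# Sketch of why dispatch order matters, under the (hedged) assumption that
# split-prefill batches also report as extend batches.
from enum import Enum, auto


class ForwardMode(Enum):
    EXTEND = auto()
    SPLIT_PREFILL = auto()

    def is_extend(self) -> bool:
        # Assumption for illustration: split prefill is a kind of extend.
        return self in (ForwardMode.EXTEND, ForwardMode.SPLIT_PREFILL)

    def is_split_prefill(self) -> bool:
        return self is ForwardMode.SPLIT_PREFILL


def dispatch(mode: ForwardMode) -> str:
    # The more specific mode is tested first, as in the reordered hunk.
    if mode.is_split_prefill():
        return "forward_split_prefill"
    elif mode.is_extend():
        return "forward_extend"
    return "other"


assert dispatch(ForwardMode.SPLIT_PREFILL) == "forward_split_prefill"
assert dispatch(ForwardMode.EXTEND) == "forward_extend"
```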
sglang/srt/model_executor/npu_graph_runner.py CHANGED

```diff
@@ -75,9 +75,13 @@ class NPUGraphRunner(CudaGraphRunner):
 
         # Replay
         if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
-
-
-
+            if forward_batch.forward_mode.is_target_verify():
+                seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
+                seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
+            else:
+                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                    self.bs - self.raw_bs
+                )
             thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
             thread.start()
             self.graphs[self.bs].replay()
```
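Graph replay on NPU feeds a fixed-size `seq_lens` list whose tail is zero-padded up to the captured batch size; during target-verify each slot additionally grows by the speculative tokens per batch slot. A small sketch of that padding arithmetic (values illustrative):

```python
# Sketch of padding per-request seq_lens to the captured graph batch size,
# mirroring the NPU replay path above (constants illustrative).
import torch


def pad_seq_lens(seq_lens: torch.Tensor, graph_bs: int, draft_tokens: int = 0) -> list:
    # During target-verify, each request's length grows by the speculative
    # tokens per batch slot before replay.
    lens = (seq_lens + draft_tokens).cpu().tolist()
    return lens + [0] * (graph_bs - len(lens))  # unused graph slots padded with 0


assert pad_seq_lens(torch.tensor([5, 9]), graph_bs=4, draft_tokens=2) == [7, 11, 0, 0]
```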
sglang/srt/model_loader/weight_utils.py CHANGED

```diff
@@ -238,7 +238,7 @@ def get_quant_config(
         if model_config.quantization == "bitsandbytes":
             config["adapter_name_or_path"] = model_name_or_path
         elif model_config.quantization.startswith("modelopt") and (
-            config
+            config.get("producer", {}).get("name", "").startswith("modelopt")
         ):
             quant_algo = config["quantization"]["quant_algo"]
             if quant_algo is None:
```
sglang/srt/models/bailing_moe.py CHANGED

```diff
@@ -420,14 +420,21 @@ class BailingMoEAttention(nn.Module):
         attn_tp_size = get_attention_tp_size()
 
         assert self.total_num_heads % attn_tp_size == 0
-
+        if self.total_kv_heads >= attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_kv_heads % attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert attn_tp_size % self.total_kv_heads == 0
         assert self.total_num_heads >= self.total_kv_heads
 
         self.num_heads = self.total_num_heads // attn_tp_size
         self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
         self.q_size = self.head_dim * self.num_heads
 
-        self.num_kv_heads = self.total_kv_heads // attn_tp_size
+        self.num_kv_heads = max(1, self.total_kv_heads // attn_tp_size)
         self.kv_size = max(1, self.num_kv_heads * self.head_dim)
 
         self.scale = self.head_dim**-0.5
```
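The new asserts encode the standard partition-or-replicate rule for grouped KV heads under tensor parallelism, and `max(1, ...)` keeps at least one local KV head in the replicated regime (where the floor division would otherwise yield 0). A worked example:

```python
# Worked example of the KV-head partition/replicate rule from the hunk above.
def num_local_kv_heads(total_kv_heads: int, attn_tp_size: int) -> int:
    if total_kv_heads >= attn_tp_size:
        # Partition: each TP rank owns a disjoint slice of the KV heads.
        assert total_kv_heads % attn_tp_size == 0
    else:
        # Replicate: several ranks share a copy of the same KV head.
        assert attn_tp_size % total_kv_heads == 0
    return max(1, total_kv_heads // attn_tp_size)


assert num_local_kv_heads(total_kv_heads=8, attn_tp_size=4) == 2  # partitioned
assert num_local_kv_heads(total_kv_heads=2, attn_tp_size=8) == 1  # replicated
```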
sglang/srt/models/deepseek_nextn.py CHANGED

```diff
@@ -38,12 +38,13 @@ from sglang.srt.models.deepseek_v2 import (
     enable_nextn_moe_bf16_cast_to_fp8,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_npu
 
 logger = logging.getLogger(__name__)
 
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 
 class DeepseekModelNextN(nn.Module):
```
```diff
@@ -85,13 +86,21 @@ class DeepseekModelNextN(nn.Module):
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
 
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+
+        layer_name = "decoder"
+        if _is_npu and (
+            get_global_server_args().speculative_draft_model_path
+            == get_global_server_args().model_path
+        ):
+            layer_name = "layers." + str(config.num_hidden_layers)
+
         self.decoder = DeepseekV2DecoderLayer(
             config,
             0,
             quant_config=quant_config,
             moe_quant_config=moe_quant_config,
             is_nextn=True,
-            prefix=add_prefix(
+            prefix=add_prefix(layer_name, prefix),
             alt_stream=self.alt_stream,
         )
```
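On NPU, when the speculative draft model path equals the base model path, the NextN block is registered under the base checkpoint's naming scheme (`layers.<num_hidden_layers>`) instead of `decoder`, presumably so its weights resolve from the shared checkpoint. The selection logic in isolation:

```python
# The layer-name selection from the hunk above, extracted as a function
# (the function name itself is hypothetical).
def pick_layer_name(is_npu: bool, draft_path: str, model_path: str, n_layers: int) -> str:
    if is_npu and draft_path == model_path:
        # Reuse the base checkpoint's "layers.<n>" naming for the NextN block.
        return f"layers.{n_layers}"
    return "decoder"


assert pick_layer_name(True, "m", "m", 61) == "layers.61"
assert pick_layer_name(False, "m", "m", 61) == "decoder"
```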