sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/sampling/custom_logit_processor.py
CHANGED

@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from functools import lru_cache
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
 import dill
 import orjson
@@ -126,3 +126,69 @@ class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
     THINKING_START_TOKEN_ID: int = 128798
     THINKING_END_TOKEN_ID: int = 128799
     NEW_LINE_TOKEN_ID: int = 201
+
+
+# Adapted from DeepSeek's implementation: https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/ngram_norepeat.py
+class DeepseekOCRNoRepeatNGramLogitProcessor(CustomLogitProcessor):
+    """Block n-gram repetitions within a sliding window for DeepSeek-OCR outputs."""
+
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        if not custom_param_list:
+            return logits
+
+        for batch_idx, params in enumerate(custom_param_list):
+            if not params:
+                continue
+
+            req = params.get("__req__")
+            if req is None:
+                continue
+
+            try:
+                ngram_size = int(params.get("ngram_size") or 0)
+                window_size = int(params.get("window_size") or 0)
+            except (TypeError, ValueError):
+                continue
+
+            if ngram_size <= 0 or window_size <= 0:
+                continue
+
+            sequence: List[int] = req.origin_input_ids + req.output_ids
+            if len(sequence) < ngram_size:
+                continue
+
+            search_start = max(0, len(sequence) - window_size)
+            search_end = len(sequence) - ngram_size + 1
+            if search_end <= search_start:
+                continue
+
+            if ngram_size > 1:
+                current_prefix = tuple(sequence[-(ngram_size - 1) :])
+            else:
+                current_prefix = tuple()
+
+            banned_tokens: Set[int] = set()
+            for idx in range(search_start, search_end):
+                ngram = sequence[idx : idx + ngram_size]
+                if ngram_size == 1 or tuple(ngram[:-1]) == current_prefix:
+                    banned_tokens.add(ngram[-1])
+
+            whitelist_ids = params.get("whitelist_token_ids") or []
+            try:
+                whitelist = {int(token_id) for token_id in whitelist_ids}
+            except (TypeError, ValueError):
+                whitelist = set()
+
+            banned_tokens.difference_update(whitelist)
+
+            if not banned_tokens:
+                continue
+
+            indices = list(banned_tokens)
+            logits[batch_idx, indices] = -float("inf")
+
+        return logits
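The new processor bans any token that would complete an n-gram already present in the trailing window of the prompt-plus-output sequence. A minimal standalone sketch of that rule, with FakeReq as a hypothetical stand-in for the scheduler request object carried in the __req__ param:

import torch

class FakeReq:
    # Hypothetical stand-in exposing only the two fields the processor reads.
    def __init__(self, origin_input_ids, output_ids):
        self.origin_input_ids = origin_input_ids
        self.output_ids = output_ids

req = FakeReq(origin_input_ids=[1, 5, 6, 7, 2], output_ids=[5, 6])
ngram_size, window_size = 3, 16

# The core rule of DeepseekOCRNoRepeatNGramLogitProcessor.__call__, inlined:
seq = req.origin_input_ids + req.output_ids      # [1, 5, 6, 7, 2, 5, 6]
prefix = tuple(seq[-(ngram_size - 1):])          # (5, 6): the last two tokens
start = max(0, len(seq) - window_size)
banned = {
    seq[i + ngram_size - 1]
    for i in range(start, len(seq) - ngram_size + 1)
    if tuple(seq[i : i + ngram_size - 1]) == prefix
}

logits = torch.zeros(1, 10)
logits[0, list(banned)] = -float("inf")          # bans token 7: (5, 6, 7) already occurred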
sglang/srt/sampling/penaltylib/frequency_penalty.py
CHANGED

@@ -1,9 +1,6 @@
 import torch
 
-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
 
 
 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
     Frequency penalizer penalizes tokens based on their frequency in the output.
     """
 
-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.frequency_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
             [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("frequency_penalties", "cumulated_frequency_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
sglang/srt/sampling/penaltylib/min_new_tokens.py
CHANGED

@@ -1,9 +1,6 @@
 import torch
 
-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
 
 
 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
     Min new tokens penalizer penalizes tokens based on the length of the output.
     """
 
-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
@@ -92,3 +85,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         self.len_output_tokens = torch.cat(
             [self.len_output_tokens, their.len_output_tokens], dim=0
         )
+
+    # Explicit resource cleanup to aid GC and free CUDA memory promptly
+    def _teardown(self) -> None:
+        for name in ("min_new_tokens", "stop_token_penalties", "len_output_tokens"):
+            if hasattr(self, name):
+                delattr(self, name)
sglang/srt/sampling/penaltylib/orchestrator.py
CHANGED

@@ -77,9 +77,8 @@ class BatchedPenalizerOrchestrator:
             return
 
         if len(keep_indices) == 0:
-
-
-                penalizer.teardown()
+            # No requests left in the batch, fully release orchestrator resources
+            self.release()
             return
 
         is_required = False
@@ -92,6 +91,23 @@ class BatchedPenalizerOrchestrator:
                 penalizer.teardown()
         self.is_required = is_required
 
+    # Resource management helpers
+    def release(self) -> None:
+        """Release all penalizers and break references so GC can reclaim promptly."""
+        for penalizer in self.penalizers.values():
+            penalizer.teardown()
+        self.penalizers.clear()
+        # Break reference to ScheduleBatch
+        self._batch_ref = None
+        self.is_required = False
+
+    # Context manager support
+    def __enter__(self) -> "BatchedPenalizerOrchestrator":
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> None:
+        self.release()
+
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
         Merge the penalizers of another orchestrator into this one.
@@ -116,6 +132,22 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """
 
+    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+        self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
+            weakref.ref(orchestrator)
+        )
+        self._is_prepared = False
+
+    @property
+    def orchestrator(self) -> BatchedPenalizerOrchestrator:
+        orch: Optional[BatchedPenalizerOrchestrator] = self._orchestrator_ref()
+        # This should never happen, but we need to handle it gracefully
+        if orch is None:
+            raise RuntimeError(
+                "BatchedPenalizerOrchestrator has been garbage-collected"
+            )
+        return orch
+
     def is_prepared(self) -> bool:
         return self._is_prepared
 
@@ -135,6 +167,7 @@ class _BatchedPenalizer(abc.ABC):
         return False
 
     def teardown(self):
+        self._teardown()
         self._is_prepared = False
 
     def cumulate_output_tokens(self, output_ids: torch.Tensor):
@@ -207,3 +240,10 @@ class _BatchedPenalizer(abc.ABC):
         Merge the penalizer with another penalizer.
         """
         pass
+
+    @abc.abstractmethod
+    def _teardown(self):
+        """
+        Teardown the penalizer.
+        """
+        pass
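Two lifecycle changes above are worth spelling out: penalizers now reach their orchestrator through a weakref, so a penalizer that outlives its batch can no longer keep the orchestrator (and transitively the ScheduleBatch) reachable, and teardown() now delegates to an abstract _teardown() hook. The weakref pattern in isolation (Owner and Member are illustrative names, not sglang classes):

import gc
import weakref

class Owner:  # stand-in for BatchedPenalizerOrchestrator
    pass

class Member:  # stand-in for _BatchedPenalizer
    def __init__(self, owner: Owner):
        self._owner_ref = weakref.ref(owner)  # holds no strong back-reference

    @property
    def owner(self) -> Owner:
        owner = self._owner_ref()
        if owner is None:
            raise RuntimeError("owner has been garbage-collected")
        return owner

o = Owner()
m = Member(o)
assert m.owner is o
del o         # drop the only strong reference to the owner
gc.collect()  # redundant under CPython refcounting, but explicit
try:
    m.owner
except RuntimeError:
    print("dangling member detected instead of silently pinning the owner")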
sglang/srt/sampling/penaltylib/presence_penalty.py
CHANGED

@@ -1,9 +1,6 @@
 import torch
 
-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
 
 
 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
     Presence penalizer penalizes tokens based on their presence in the output.
     """
 
-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.presence_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("presence_penalties", "cumulated_presence_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
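All three penalizers (frequency and min-new-tokens earlier, presence above) implement the new abstract _teardown() hook the same way: delattr on their tensor attributes, so the CUDA memory those tensors pin is released as soon as teardown() runs. A simplified mirror of the contract, with MyPenalizer and its tensor name purely illustrative:

import abc

import torch

class _PenalizerSketch(abc.ABC):
    # Simplified mirror of the _BatchedPenalizer teardown contract.
    _is_prepared = False

    def teardown(self):
        self._teardown()          # subclass releases whatever it allocated
        self._is_prepared = False

    @abc.abstractmethod
    def _teardown(self): ...

class MyPenalizer(_PenalizerSketch):
    def prepare(self, batch_size: int, vocab_size: int):
        self.penalties = torch.zeros(batch_size, vocab_size)
        self._is_prepared = True

    def _teardown(self):
        # delattr removes the attribute, releasing the tensor it pins
        for name in ("penalties",):
            if hasattr(self, name):
                delattr(self, name)

p = MyPenalizer()
p.prepare(8, 32000)
p.teardown()
assert not hasattr(p, "penalties") and not p._is_prepared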
    
sglang/srt/server_args.py
CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.utils.common import (
     get_device,
     get_device_memory_capacity,
     get_device_sm,
+    is_blackwell_supported,
     is_cuda,
     is_fa3_default_architecture,
     is_flashinfer_available,
@@ -98,6 +99,7 @@ QUANTIZATION_CHOICES = [
     "qoq",
     "w4afp8",
     "mxfp4",
+    "auto-round",
     "compressed-tensors",  # for Ktransformers
 ]
 
@@ -133,7 +135,18 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
 
 DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
 
-
+RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND = ["fa3", "triton"]
+
+DEFAULT_LORA_EVICTION_POLICY = "lru"
+
+NSA_CHOICES = [
+    "flashmla_sparse",
+    "flashmla_kv",
+    "flashmla_auto",
+    "fa3",
+    "tilelang",
+    "aiter",
+]
 
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
 
@@ -179,6 +192,10 @@ def add_deterministic_attention_backend_choices(choices):
     DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
 
 
+def add_radix_supported_deterministic_attention_backend_choices(choices):
+    RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND.extend(choices)
+
+
 def add_radix_eviction_policy_choices(choices):
     RADIX_EVICTION_POLICY_CHOICES.extend(choices)
 
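The new list and helper mirror the existing extension hooks, letting an out-of-tree attention backend register itself as deterministic and, separately, as safe to combine with the radix prefix cache. A hedged usage sketch ("my_backend" is a made-up name):

from sglang.srt.server_args import (
    add_deterministic_attention_backend_choices,
    add_radix_supported_deterministic_attention_backend_choices,
)

# Register a custom backend for deterministic inference, and additionally
# mark it as compatible with the radix cache.
add_deterministic_attention_backend_choices(["my_backend"])
add_radix_supported_deterministic_attention_backend_choices(["my_backend"])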
@@ -288,7 +305,7 @@ class ServerArgs:
     enable_request_time_stats_logging: bool = False
     kv_events_config: Optional[str] = None
     enable_trace: bool = False
-
+    otlp_traces_endpoint: str = "localhost:4317"
 
     # API related
     api_key: Optional[str] = None
@@ -329,7 +346,7 @@ class ServerArgs:
     max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
     lora_eviction_policy: str = "lru"
-    lora_backend: str = "
+    lora_backend: str = "csgmv"
     max_lora_chunk_size: Optional[int] = 16
 
     # Kernel backend
@@ -494,6 +511,9 @@ class ServerArgs:
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
+    # -1 mean dump all layers.
+    debug_tensor_dump_layers: int = -1
+    # TODO(guoyuhong): clean the old dumper code.
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False
 
@@ -522,6 +542,10 @@ class ServerArgs:
     pdmux_config_path: Optional[str] = None
     sm_group_num: int = 8
 
+    # For Multi-Modal
+    mm_max_concurrent_calls: int = 32
+    mm_per_request_timeout: float = 10.0
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
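The two new multi-modal fields bound how many preprocessing calls run concurrently and how long each may take; their consumer is the new sglang/srt/managers/async_mm_data_processor.py. A hedged asyncio sketch of the intended semantics (function names here are illustrative, not the actual implementation):

import asyncio

MM_MAX_CONCURRENT_CALLS = 32   # corresponds to mm_max_concurrent_calls
MM_PER_REQUEST_TIMEOUT = 10.0  # corresponds to mm_per_request_timeout, in seconds

sem = asyncio.Semaphore(MM_MAX_CONCURRENT_CALLS)

async def decode_item(item: str) -> str:
    await asyncio.sleep(0.01)  # placeholder for real image/video preprocessing
    return f"features({item})"

async def process_mm_item(item: str) -> str:
    async with sem:  # at most 32 decodes in flight across all requests
        return await asyncio.wait_for(decode_item(item), timeout=MM_PER_REQUEST_TIMEOUT)

async def main():
    print(await asyncio.gather(*(process_mm_item(f"img{i}") for i in range(4))))

asyncio.run(main())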
@@ -811,7 +835,7 @@ class ServerArgs:
             capture_bs = (
                 list(range(1, 9, 1))
                 + list(range(10, 33, 2))
-                + list(range(40, 
+                + list(range(40, 65, 4))
                 + list(range(72, 257, 8))
                 + list(range(272, self.cuda_graph_max_bs + 1, 16))
             )
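Evaluated concretely, the capture schedule stays dense for small batch sizes and grows sparser as the batch size increases; with cuda_graph_max_bs assumed to be 512, the list(range(40, 65, 4)) segment contributes 40, 44, ..., 64:

cuda_graph_max_bs = 512  # assumed here purely for illustration
capture_bs = (
    list(range(1, 9, 1))        # 1..8, every size
    + list(range(10, 33, 2))    # 10..32 in steps of 2
    + list(range(40, 65, 4))    # 40..64 in steps of 4
    + list(range(72, 257, 8))   # 72..256 in steps of 8
    + list(range(272, cuda_graph_max_bs + 1, 16))  # 272..512 in steps of 16
)
print(capture_bs[:14])  # [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20]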
@@ -874,7 +898,7 @@ class ServerArgs:
                 logger.info(
                     "Enable FlashInfer AllReduce Fusion on sm100 for DeepseekV3ForCausalLM"
                 )
-                if self.moe_runner_backend == "auto":
+                if self.moe_a2a_backend == "none" and self.moe_runner_backend == "auto":
                     self.moe_runner_backend = "flashinfer_trtllm"
                     logger.info(
                         "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
@@ -912,7 +936,7 @@ class ServerArgs:
                 f"- Decode: {decode_attn_backend}\n"
             )
 
-            if 
+            if is_blackwell_supported():
                 if not self.enable_dp_attention:
                     self.enable_flashinfer_allreduce_fusion = True
                     logger.info(
@@ -924,7 +948,7 @@ class ServerArgs:
                 and quantization_config.get("quant_method") == "mxfp4"
             )
 
-            if 
+            if is_blackwell_supported() and is_mxfp4_quant_format:
                 self.moe_runner_backend = "flashinfer_mxfp4"
                 logger.warning(
                     "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
@@ -960,6 +984,12 @@ class ServerArgs:
                 logger.warning(
                     "Use trtllm_mha as attention backend on sm100 for Llama4 model"
                 )
+            if is_sm100_supported() and self.moe_runner_backend == "auto":
+                if self.quantization in {"fp8", "modelopt_fp8"}:
+                    self.moe_runner_backend = "flashinfer_trtllm"
+                    logger.info(
+                        "Use flashinfer_trtllm as MoE runner backend on SM100 for Llama4"
+                    )
         elif model_arch in [
             "Gemma2ForCausalLM",
             "Gemma3ForCausalLM",
@@ -998,6 +1028,11 @@ class ServerArgs:
             logger.info(
                 f"Using {self.attention_backend} as attention backend for {model_arch}."
             )
+        elif model_arch in ["KimiLinearForCausalLM"]:
+            logger.warning(
+                f"Disabling Radix Cache for {model_arch} as it is not yet supported."
+            )
+            self.disable_radix_cache = True
 
         if is_deepseek_nsa(hf_config):
             if (
@@ -1020,16 +1055,30 @@ class ServerArgs:
                 import torch
 
                 major, _ = torch.cuda.get_device_capability()
-                if 
-                    self.kv_cache_dtype = "fp8_e4m3"
-                    logger.warning(
+                if self.kv_cache_dtype == "auto":
+                    self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
+                    logger.warning(
+                        f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
+                    )
+                if self.kv_cache_dtype == "bf16":
+                    self.kv_cache_dtype = "bfloat16"
+                assert self.kv_cache_dtype in [
+                    "bfloat16",
+                    "fp8_e4m3",
+                ], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"
 
                 if self.kv_cache_dtype == "fp8_e4m3":
-
+                    # flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
+                    self.nsa_prefill_backend = "flashmla_auto"
                     self.nsa_decode_backend = "flashmla_kv"
                     logger.warning(
-                        "Setting NSA backend to flashmla_kv for FP8 KV Cache."
+                        "Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
                     )
+                else:
+                    # set prefill/decode backends for Blackwell. The default settings are for Hopper.
+                    if major >= 10:
+                        self.nsa_prefill_backend = "flashmla_sparse"
+                        self.nsa_decode_backend = "flashmla_sparse"
 
                 # Logging env vars for NSA
                 from sglang.srt.layers.attention.nsa.utils import (
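Distilled into a pure function, the DeepSeek-NSA defaults above key off both the resolved KV-cache dtype and the GPU generation (SM major version). This sketch just re-expresses that branch; None means the pre-existing Hopper defaults are left untouched:

def pick_nsa_backends(kv_cache_dtype: str, sm_major: int):
    # Resolve the dtype the way __post_init__ does for DeepSeek NSA.
    if kv_cache_dtype == "auto":
        kv_cache_dtype = "fp8_e4m3" if sm_major >= 10 else "bfloat16"
    if kv_cache_dtype == "bf16":
        kv_cache_dtype = "bfloat16"
    assert kv_cache_dtype in ("bfloat16", "fp8_e4m3")

    if kv_cache_dtype == "fp8_e4m3":
        # flashmla_auto picks flashmla_sparse or flashmla_kv at runtime
        return kv_cache_dtype, "flashmla_auto", "flashmla_kv"
    if sm_major >= 10:  # Blackwell with a bf16 cache
        return kv_cache_dtype, "flashmla_sparse", "flashmla_sparse"
    return kv_cache_dtype, None, None  # Hopper keeps the default backends

print(pick_nsa_backends("auto", 10))  # ('fp8_e4m3', 'flashmla_auto', 'flashmla_kv')
print(pick_nsa_backends("auto", 9))   # ('bfloat16', None, None)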
@@ -1144,7 +1193,7 @@ class ServerArgs:
             self.attention_backend == "trtllm_mla"
             or self.decode_attention_backend == "trtllm_mla"
         ):
-            if not is_sm100_supported():
+            if not is_blackwell_supported():
                 raise ValueError(
                     "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                 )
@@ -1196,7 +1245,7 @@ class ServerArgs:
         # AMD platforms backends
         if self.attention_backend == "aiter":
             if model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.
+                self.mem_fraction_static *= 0.85
 
         # NPU platforms backends
         if is_npu() and self.attention_backend in ["ascend"]:
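For example, a mem_fraction_static of 0.90 drops to 0.90 * 0.85 = 0.765 whenever the aiter backend serves a context length above 8192.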
@@ -1311,8 +1360,10 @@ class ServerArgs:
 
         if self.moe_runner_backend == "flashinfer_trtllm":
             assert (
-                self.quantization == "modelopt_fp4" or self.quantization == "modelopt_fp8"
-            ), "modelopt_fp4 or modelopt_fp8 quantization is required for Flashinfer TRTLLM MoE"
+                self.quantization == "modelopt_fp4"
+                or self.quantization == "modelopt_fp8"
+                or self.quantization == "fp8"
+            ), "modelopt_fp4, modelopt_fp8 or fp8 quantization is required for Flashinfer TRTLLM MoE"
             self.disable_shared_experts_fusion = True
             logger.warning(
                 "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
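The chained equality tests above are equivalent to a simple membership check. A minimal standalone sketch of the same predicate (not the actual ServerArgs code):

    def check_flashinfer_trtllm_moe_quant(quantization: str) -> None:
        # Same condition as the assert in the diff, written as a membership test
        assert quantization in ("modelopt_fp4", "modelopt_fp8", "fp8"), (
            "modelopt_fp4, modelopt_fp8 or fp8 quantization is required "
            "for Flashinfer TRTLLM MoE"
        )

    check_flashinfer_trtllm_moe_quant("fp8")  # passes; anything else raises AssertionError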
@@ -1713,13 +1764,17 @@ class ServerArgs:
                     f"but you explicitly specified '{self.attention_backend}'."
                 )
 
-            if
-                if
+            if is_deepseek_model:
+                if self.attention_backend not in ["fa3", "triton"]:
                     raise ValueError(
-                        f"Currently only
+                        f"Currently only {RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND} attention backends are supported for deterministic inference with DeepSeek models. But you're using {self.attention_backend}."
                     )
 
-
+            if (
+                self.attention_backend
+                not in RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND
+            ):
+                # Currently, only certain backends support radix cache. Support for other backends is in progress
                 self.disable_radix_cache = True
                 logger.warning(
                     f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
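Restated as a standalone sketch, the two validations above gate deterministic inference and radix cache. The list contents here are assumed from the ["fa3", "triton"] check in the DeepSeek branch; the real RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND constant is defined elsewhere in sglang:

    # Assumed contents, inferred from the DeepSeek branch above
    RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND = ["fa3", "triton"]

    def radix_cache_allowed(backend: str, is_deepseek_model: bool) -> bool:
        # Returns whether radix cache can stay enabled for deterministic inference;
        # raises for unsupported DeepSeek configurations.
        if is_deepseek_model and backend not in ["fa3", "triton"]:
            raise ValueError(
                f"Currently only {RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND} "
                f"attention backends are supported for deterministic inference "
                f"with DeepSeek models. But you're using {backend}."
            )
        return backend in RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND

    assert radix_cache_allowed("fa3", is_deepseek_model=True)
    assert not radix_cache_allowed("flashinfer", is_deepseek_model=False)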
@@ -1734,7 +1789,13 @@ class ServerArgs:
             )
 
     def _handle_other_validations(self):
-        pass
+        # Handle model inference tensor dump.
+        if self.debug_tensor_dump_output_folder is not None:
+            logger.warning(
+                "Cuda graph and server warmup are disabled because of using tensor dump mode"
+            )
+            self.disable_cuda_graph = True
+            self.skip_server_warmup = True
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -2315,7 +2376,7 @@ class ServerArgs:
             help="Enable opentelemetry trace",
         )
         parser.add_argument(
-            "--oltp-traces-endpoint",
+            "--otlp-traces-endpoint",
             type=str,
             default="localhost:4317",
             help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
@@ -3325,6 +3386,12 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_output_folder,
             help="The output folder for dumping tensors.",
         )
+        parser.add_argument(
+            "--debug-tensor-dump-layers",
+            type=int,
+            default=-1,
+            help="The layer number for dumping tensors.",
+        )
         parser.add_argument(
             "--debug-tensor-dump-input-file",
             type=str,
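Taken together with _handle_other_validations above, the tensor-dump flags can be reproduced in a standalone parser like the one below. This is an illustrative sketch, not sglang's actual parser, and the /tmp/dumps path is made up:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug-tensor-dump-output-folder", type=str, default=None,
                        help="The output folder for dumping tensors.")
    parser.add_argument("--debug-tensor-dump-layers", type=int, default=-1,
                        help="The layer number for dumping tensors.")
    args = parser.parse_args(["--debug-tensor-dump-output-folder", "/tmp/dumps",
                              "--debug-tensor-dump-layers", "3"])

    # Per _handle_other_validations, a non-None dump folder disables CUDA graph
    # and skips server warmup.
    disable_cuda_graph = args.debug_tensor_dump_output_folder is not None
    skip_server_warmup = args.debug_tensor_dump_output_folder is not None
    print(disable_cuda_graph, skip_server_warmup)  # True True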
@@ -3461,6 +3528,20 @@ class ServerArgs:
             help="Read CLI options from a config file. Must be a YAML file with configuration options.",
         )
 
+        # For Multi-Modal
+        parser.add_argument(
+            "--mm-max-concurrent-calls",
+            type=int,
+            default=ServerArgs.mm_max_concurrent_calls,
+            help="The max concurrent calls for async mm data processing.",
+        )
+        parser.add_argument(
+            "--mm-per-request-timeout",
+            type=int,
+            default=ServerArgs.mm_per_request_timeout,
+            help="The timeout for each multi-modal request in seconds.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
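The two new multi-modal knobs bound concurrency and wall-clock time for async preprocessing. The sketch below shows the generic semaphore-plus-timeout pattern such flags usually gate; it is an illustration, not sglang's implementation:

    import asyncio

    async def _do_work(item: str) -> str:
        await asyncio.sleep(0.01)  # stand-in for image/audio preprocessing
        return item.upper()

    async def process_all(items, max_concurrent_calls: int, per_request_timeout: float):
        sem = asyncio.Semaphore(max_concurrent_calls)  # --mm-max-concurrent-calls

        async def one(item: str) -> str:
            async with sem:
                # --mm-per-request-timeout: each item gets its own deadline
                return await asyncio.wait_for(_do_work(item), timeout=per_request_timeout)

        return await asyncio.gather(*(one(i) for i in items))

    print(asyncio.run(process_all(["a", "b"], max_concurrent_calls=2,
                                  per_request_timeout=5.0)))  # ['A', 'B']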
@@ -98,7 +98,10 @@ def execute_sbo(
         ):
             forward_shared_experts()
 
-    hidden_states = experts.dispatcher.combine(combine_input=combine_input)
+    hidden_states = experts.dispatcher.combine(
+        combine_input=combine_input,
+        overlap_args=combine_overlap_args,
+    )
 
     return hidden_states
 
@@ -49,6 +49,7 @@ class DraftBackendFactory:
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,
+            "ascend": self._create_ascend_decode_backend,
         }
 
         return self._create_backend(
@@ -72,6 +73,7 @@ class DraftBackendFactory:
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,
+            "ascend": self._create_ascend_prefill_backend,
         }
         backend_name = (
             "decode_attention_backend"
@@ -173,6 +175,15 @@ class DraftBackendFactory:
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )
 
+    def _create_ascend_decode_backend(self):
+        from sglang.srt.layers.attention.ascend_backend import (
+            AscendAttnMultiStepDraftBackend,
+        )
+
+        return AscendAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_flashinfer_prefill_backend(self):
         if not get_global_server_args().use_mla_backend:
             from sglang.srt.layers.attention.flashinfer_backend import (
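The hunks above and below register "ascend" in the factory's name-to-creator tables and add the creator methods themselves. A minimal standalone version of that registry-dispatch pattern (illustrative class and return values, not the real DraftBackendFactory):

    class TinyDraftBackendFactory:
        def _create_triton_decode_backend(self):
            return "triton-decode"

        def _create_ascend_decode_backend(self):
            return "ascend-decode"  # the entry this release adds

        def create_decode_backend(self, name: str):
            creators = {
                "triton": self._create_triton_decode_backend,
                "ascend": self._create_ascend_decode_backend,
            }
            if name not in creators:
                raise ValueError(f"unsupported draft decode backend: {name}")
            return creators[name]()

    print(TinyDraftBackendFactory().create_decode_backend("ascend"))  # ascend-decode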
@@ -219,6 +230,11 @@ class DraftBackendFactory:
 
         return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
 
+    def _create_ascend_prefill_backend(self):
+        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+        return AscendAttnBackend(self.draft_model_runner)
+
     def _create_flashmla_prefill_backend(self):
         logger.warning(
             "flashmla prefill backend is not yet supported for draft extend."