sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/mem_cache/memory_pool.py CHANGED

@@ -17,7 +17,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from sglang.srt.configs.mamba_utils import Mamba2CacheParams
+from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
 from sglang.srt.layers.attention.nsa import index_buf_accessor
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np

@@ -59,7 +59,9 @@ if _is_npu:
     import torch_npu
 
 
-def get_tensor_size_bytes(t: torch.Tensor):
+def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
+    if isinstance(t, list):
+        return sum(get_tensor_size_bytes(x) for x in t)
     return np.prod(t.shape) * t.dtype.itemsize
 
 
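
To make the change above concrete: get_tensor_size_bytes() now accepts either a single tensor or a list of tensors and sums the sizes recursively, presumably so the list-valued KDA conv states added elsewhere in this diff can be measured with the same call. A minimal standalone sketch of the same logic:

from typing import List, Union

import numpy as np
import torch


def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
    # Recurse into lists, otherwise count elements * bytes per element.
    if isinstance(t, list):
        return sum(get_tensor_size_bytes(x) for x in t)
    return np.prod(t.shape) * t.dtype.itemsize


single = torch.zeros(4, 8, dtype=torch.float16)    # 4 * 8 * 2 = 64 bytes
per_layer = [torch.zeros(2, 3), torch.zeros(5)]    # 24 + 20 = 44 bytes (float32)
print(get_tensor_size_bytes(single), get_tensor_size_bytes(per_layer))
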
@@ -116,10 +118,15 @@ class ReqToTokenPool:
 class MambaPool:
     @dataclass(frozen=True, kw_only=True)
     class State:
-        conv: torch.Tensor
+        conv: Union[torch.Tensor, List[torch.Tensor]]
         temporal: torch.Tensor
 
         def at_layer_idx(self, layer: int):
+            if isinstance(self.conv, list):
+                return type(self)(
+                    conv=[v[layer] for v in self.conv],
+                    temporal=self.temporal[layer],
+                )
             return type(self)(**{k: v[layer] for k, v in vars(self).items()})
 
         def mem_usage_bytes(self):
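
The at_layer_idx() change keeps the generic field-by-field slicing (type(self)(**{k: v[layer] ...})) when every field is a stacked tensor, and adds an explicit branch for the list-of-tensors case, where each buffer in the list is sliced separately. A small self-contained sketch of that pattern, simplified from the class above:

from dataclasses import dataclass
from typing import List, Union

import torch


@dataclass(frozen=True, kw_only=True)
class State:
    conv: Union[torch.Tensor, List[torch.Tensor]]
    temporal: torch.Tensor

    def at_layer_idx(self, layer: int):
        # A list of conv buffers is sliced element-wise; a single stacked
        # tensor is sliced field-by-field via the dataclass constructor.
        if isinstance(self.conv, list):
            return type(self)(
                conv=[v[layer] for v in self.conv],
                temporal=self.temporal[layer],
            )
        return type(self)(**{k: v[layer] for k, v in vars(self).items()})


s = State(conv=torch.zeros(2, 4, 8), temporal=torch.zeros(2, 16))
layer0 = s.at_layer_idx(0)   # layer0.conv has shape (4, 8), layer0.temporal (16,)
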
@@ -127,14 +134,14 @@ class MambaPool:
 
     @dataclass(frozen=True, kw_only=True)
     class SpeculativeState(State):
-        intermediate_ssm: torch.Tensor
+        intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
         intermediate_conv_window: torch.Tensor
 
     def __init__(
         self,
         *,
         size: int,
-        cache_params: "Mamba2CacheParams",
+        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         device: str,
         speculative_num_draft_tokens: Optional[int] = None,
     ):

@@ -157,18 +164,29 @@ class MambaPool:
         else:
             self.custom_mem_pool = None
 
+        self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
         with (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.enable_custom_mem_pool
             else nullcontext()
         ):
-            ... (7 deleted lines not captured in source)
+            if self.is_kda_cache:
+                conv_state = [
+                    torch.zeros(
+                        size=(num_mamba_layers, size + 1) + conv_shape,
+                        dtype=conv_dtype,
+                        device=device,
+                    )
+                    for conv_shape in conv_state_shape
+                ]
+            else:
+                # assume conv_state = (dim, state_len)
+                assert conv_state_shape[0] > conv_state_shape[1]
+                conv_state = torch.zeros(
+                    size=(num_mamba_layers, size + 1) + conv_state_shape,
+                    dtype=conv_dtype,
+                    device=device,
+                )
             temporal_state = torch.zeros(
                 size=(num_mamba_layers, size + 1) + temporal_state_shape,
                 dtype=ssm_dtype,

@@ -191,17 +209,34 @@ class MambaPool:
             )
             # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
             # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
-            ... (11 deleted lines not captured in source)
+
+            if self.is_kda_cache:
+                intermediate_conv_window_cache = [
+                    torch.zeros(
+                        size=(
+                            num_mamba_layers,
+                            size + 1,
+                            speculative_num_draft_tokens,
+                            conv_shape[0],
+                            conv_shape[1],
+                        ),
+                        dtype=conv_dtype,
+                        device="cuda",
+                    )
+                    for conv_shape in conv_state_shape
+                ]
+            else:
+                intermediate_conv_window_cache = torch.zeros(
+                    size=(
+                        num_mamba_layers,
+                        size + 1,
+                        speculative_num_draft_tokens,
+                        conv_state_shape[0],
+                        conv_state_shape[1],
+                    ),
+                    dtype=conv_dtype,
+                    device="cuda",
+                )
             self.mamba_cache = self.SpeculativeState(
                 conv=conv_state,
                 temporal=temporal_state,

@@ -255,15 +290,25 @@ class MambaPool:
         if free_index.numel() == 0:
             return
         self.free_slots = torch.cat((self.free_slots, free_index))
-        ... (3 deleted lines not captured in source)
+        if self.is_kda_cache:
+            for i in range(len(self.mamba_cache.conv)):
+                self.mamba_cache.conv[i][:, free_index] = 0
+        else:
+            self.mamba_cache.conv[:, free_index] = 0
+        self.mamba_cache.temporal[:, free_index] = 0
 
     def clear(self):
         self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
 
     def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
-        ... (1 deleted line not captured in source)
+        if self.is_kda_cache:
+            for i in range(len(self.mamba_cache.conv)):
+                self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
+                    :, src_index
+                ]
+        else:
+            self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
+
         self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
             :, src_index
         ]

@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,
-        cache_params: "Mamba2CacheParams",
+        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         speculative_num_draft_tokens: int = None,
     ):
         super().__init__(

@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
     def _init_mamba_pool(
         self,
         size: int,
-        cache_params: "Mamba2CacheParams",
+        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         device: str,
         speculative_num_draft_tokens: int = None,
     ):

@@ -509,6 +554,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_alt_stream: bool = True,
         enable_kv_cache_copy: bool = False,
     ):
         super().__init__(

@@ -527,7 +573,9 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()
 
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = ... (rest of deleted line not captured in source)
+        self.alt_stream = (
+            self.device_module.Stream() if _is_cuda and enable_alt_stream else None
+        )
 
         if enable_kv_cache_copy:
             self._init_kv_copy_and_warmup()
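
On the enable_alt_stream change above: the secondary stream is now created only when CUDA is available and the new flag is left at its default of True; otherwise alt_stream is None and the pool simply runs on the current stream. A generic sketch of that guarded alternate-stream pattern (illustrative class and method names, not sglang's exact usage):

import torch

_is_cuda = torch.cuda.is_available()


class Pool:
    def __init__(self, enable_alt_stream: bool = True):
        # Only allocate a second stream when it can actually be used.
        self.alt_stream = torch.cuda.Stream() if _is_cuda and enable_alt_stream else None

    def copy_halves(self, dst: torch.Tensor, src: torch.Tensor):
        mid = src.shape[0] // 2
        if self.alt_stream is not None:
            # Overlap the two halves of the copy on separate streams.
            self.alt_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.alt_stream):
                dst[:mid].copy_(src[:mid], non_blocking=True)
            dst[mid:].copy_(src[mid:], non_blocking=True)
            torch.cuda.current_stream().wait_stream(self.alt_stream)
        else:
            dst.copy_(src)
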
@@ -809,6 +857,10 @@ class HybridLinearKVPool(KVCache):
         enable_kvcache_transpose: bool,
         device: str,
         mamba_pool: MambaPool,
+        # TODO: refactor mla related args
+        use_mla: bool = False,
+        kv_lora_rank: int = None,
+        qk_rope_head_dim: int = None,
     ):
         self.size = size
         self.dtype = dtype

@@ -822,25 +874,42 @@ class HybridLinearKVPool(KVCache):
         self.mamba_pool = mamba_pool
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-        ... (2 deleted lines not captured in source)
+        self.use_mla = use_mla
+        if not use_mla:
+            if _is_npu:
+                TokenToKVPoolClass = AscendTokenToKVPool
+            else:
+                TokenToKVPoolClass = MHATokenToKVPool
+            self.full_kv_pool = TokenToKVPoolClass(
+                size=size,
+                page_size=self.page_size,
+                dtype=dtype,
+                head_num=head_num,
+                head_dim=head_dim,
+                layer_num=self.full_layer_nums,
+                device=device,
+                enable_memory_saver=False,
+            )
         else:
-            TokenToKVPoolClass = ... (rest of deleted line not captured in source)
-            ... (10 deleted lines not captured in source)
+            TokenToKVPoolClass = MLATokenToKVPool
+            self.full_kv_pool = TokenToKVPoolClass(
+                size=size,
+                page_size=self.page_size,
+                dtype=dtype,
+                layer_num=self.full_layer_nums,
+                device=device,
+                kv_lora_rank=kv_lora_rank,
+                qk_rope_head_dim=qk_rope_head_dim,
+                enable_memory_saver=False,
+            )
         self.full_attention_layer_id_mapping = {
             id: i for i, id in enumerate(full_attention_layer_ids)
         }
-        ... (2 deleted lines not captured in source)
+        if use_mla:
+            self.mem_usage = self.get_kv_size_bytes() / GB
+        else:
+            k_size, v_size = self.get_kv_size_bytes()
+            self.mem_usage = (k_size + v_size) / GB
 
     def get_kv_size_bytes(self):
         return self.full_kv_pool.get_kv_size_bytes()

@@ -876,6 +945,21 @@ class HybridLinearKVPool(KVCache):
         layer_id = self._transfer_full_attention_id(layer_id)
         return self.full_kv_pool.get_kv_buffer(layer_id)
 
+    @contextmanager
+    def _transfer_id_context(self, layer: RadixAttention):
+
+        @contextmanager
+        def _patch_layer_id(layer):
+            original_layer_id = layer.layer_id
+            layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
+            try:
+                yield
+            finally:
+                layer.layer_id = original_layer_id
+
+        with _patch_layer_id(layer):
+            yield
+
     def set_kv_buffer(
         self,
         layer: RadixAttention,
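
The _transfer_id_context helper above temporarily rewrites layer.layer_id to the index used inside the shared full-attention pool and restores it afterwards, so the wrapped pool methods (set_kv_buffer, set_mla_kv_buffer, get_mla_kv_buffer below) can be reused unchanged; the try/finally guarantees the id is restored even if the wrapped call raises. A minimal standalone sketch of that patch-and-restore pattern (names here are illustrative):

from contextlib import contextmanager


class Layer:
    def __init__(self, layer_id: int):
        self.layer_id = layer_id


@contextmanager
def remapped_layer_id(layer: Layer, mapping: dict):
    # Temporarily point layer_id at the remapped index; always restore it.
    original = layer.layer_id
    layer.layer_id = mapping[layer.layer_id]
    try:
        yield layer
    finally:
        layer.layer_id = original


layer = Layer(layer_id=7)
with remapped_layer_id(layer, {7: 0}):
    assert layer.layer_id == 0   # pool code sees the remapped id
assert layer.layer_id == 7       # original id restored afterwards
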
@@ -886,19 +970,49 @@ class HybridLinearKVPool(KVCache):
         v_scale: float = 1.0,
     ):
         layer_id = self._transfer_full_attention_id(layer.layer_id)
-        self. ... (rest of deleted line not captured in source)
-        ... (8 deleted lines not captured in source)
+        if not self.use_mla:
+            self.full_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id,
+            )
+        else:
+            with self._transfer_id_context(layer):
+                self.full_kv_pool.set_kv_buffer(
+                    layer,
+                    loc,
+                    cache_k,
+                    cache_v,
+                )
 
     def get_v_head_dim(self):
         return self.full_kv_pool.get_value_buffer(0).shape[-1]
 
+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        assert self.use_mla, "set_mla_kv_buffer called when use_mla is False"
+        with self._transfer_id_context(layer):
+            self.full_kv_pool.set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)
+
+    def get_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        dst_dtype: Optional[torch.dtype] = None,
+    ):
+        assert self.use_mla, "get_mla_kv_buffer called when use_mla is False"
+        with self._transfer_id_context(layer):
+            return self.full_kv_pool.get_mla_kv_buffer(layer, loc, dst_dtype)
+
 
 class SWAKVPool(KVCache):
     """KV cache with separate pools for full and SWA attention layers."""

@@ -1137,10 +1251,10 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
-            key_cache=self.k_buffer[layer_id].view(
+            key_cache=self.k_buffer[layer_id - self.start_layer].view(
                 -1, self.page_size, self.head_num, self.head_dim
            ),
-            value_cache=self.v_buffer[layer_id].view(
+            value_cache=self.v_buffer[layer_id - self.start_layer].view(
                 -1, self.page_size, self.head_num, self.head_dim
             ),
             slot_indices=loc,

sglang/srt/mem_cache/memory_pool_host.py CHANGED

@@ -238,12 +238,16 @@ class MHATokenToKVPoolHost(HostKVCache):
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
         self.layout_dim = self.token_stride_size * self.layer_num
-        ... (1 deleted line not captured in source)
+        buffer = torch.empty(
             dims,
             dtype=self.dtype,
             device=self.device,
-            pin_memory=self.pin_memory,
         )
+        if self.pin_memory:
+            torch.cuda.cudart().cudaHostRegister(
+                buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
+            )
+        return buffer
 
     @property
     def k_buffer(self):

@@ -551,13 +555,16 @@ class MLATokenToKVPoolHost(HostKVCache):
             self.kv_lora_rank + self.qk_rope_head_dim
         ) * self.dtype.itemsize
         self.layout_dim = self.token_stride_size * self.layer_num
-
-        return torch.empty(
+        buffer = torch.empty(
             dims,
             dtype=self.dtype,
             device=self.device,
-            pin_memory=self.pin_memory,
         )
+        if self.pin_memory:
+            torch.cuda.cudart().cudaHostRegister(
+                buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
+            )
+        return buffer
 
     def load_to_device_per_layer(
         self, device_pool, host_indices, device_indices, layer_id, io_backend
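
Both host-pool hunks above replace pin_memory=True allocation with a plain torch.empty followed by an in-place cudaHostRegister call, so the buffer is page-locked after the fact instead of being allocated through PyTorch's pinned-memory allocator. A hedged sketch contrasting the two approaches (flag value 0 is cudaHostRegisterDefault):

import torch

shape = (2, 1024, 8, 128)

# Option A: let PyTorch allocate page-locked (pinned) memory directly.
pinned = torch.empty(shape, dtype=torch.float16, pin_memory=True)

# Option B (what the diff switches to): allocate pageable memory, then register it.
buffer = torch.empty(shape, dtype=torch.float16)
if torch.cuda.is_available():
    # Registration page-locks the existing pages so they can be used for
    # asynchronous DMA transfers, like a natively pinned allocation.
    torch.cuda.cudart().cudaHostRegister(
        buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
    )
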
sglang/srt/mem_cache/radix_cache.py CHANGED

@@ -533,6 +533,10 @@ class RadixCache(BasePrefixCache):
                 self.protected_size_ -= len(node.key)
                 delta += len(node.key)
             node.lock_ref -= 1
+            if node.parent is None:
+                assert (
+                    node is self.root_node
+                ), f"This request holds the node from another tree"
             node = node.parent
         return delta
 
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py CHANGED

@@ -104,7 +104,7 @@ class MooncakeStoreConfig:
             device_name=os.getenv("MOONCAKE_DEVICE", ""),
             master_server_address=os.getenv("MOONCAKE_MASTER"),
             master_metrics_port=int(
-                os.getenv("MOONCAKE_MASTER_METRICS_PORT", ... (rest of deleted line not captured in source)
+                os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_MASTER_METRICS_PORT)
             ),
             check_server=bool(os.getenv("MOONCAKE_CHECK_SERVER", DEFAULT_CHECK_SERVER)),
         )

sglang/srt/metrics/collector.py CHANGED

@@ -811,6 +811,34 @@ class TokenizerMetricsCollector:
             buckets=bucket_e2e_request_latency,
         )
 
+        # Retraction count histogram
+        self.num_retractions = Histogram(
+            name="sglang:num_retractions",
+            documentation="Histogram of retraction counts per request.",
+            labelnames=labels.keys(),
+            buckets=[
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                15,
+                20,
+                25,
+                30,
+                40,
+                50,
+                75,
+                100,
+            ],
+        )
+
     def observe_one_finished_request(
         self,
         labels: Dict[str, str],
@@ -819,6 +847,7 @@ class TokenizerMetricsCollector:
         cached_tokens: int,
         e2e_latency: float,
         has_grammar: bool,
+        retraction_count: int,
     ):
         self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**labels).inc(generation_tokens)
@@ -833,6 +862,7 @@ class TokenizerMetricsCollector:
             self.generation_tokens_histogram.labels(**labels).observe(
                 float(generation_tokens)
             )
+        self.num_retractions.labels(**labels).observe(retraction_count)
 
     def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
         self.histogram_time_to_first_token.labels(**labels).observe(value)
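Together, these collector hunks define a `sglang:num_retractions` histogram, accept a new `retraction_count` argument in `observe_one_finished_request`, and record one observation per finished request. A self-contained sketch of the same `prometheus_client` pattern (the label set below is hypothetical):

```python
from prometheus_client import Histogram, generate_latest

labels = {"model_name": "demo-model"}

num_retractions = Histogram(
    name="sglang:num_retractions",
    documentation="Histogram of retraction counts per request.",
    labelnames=labels.keys(),
    buckets=[0, 1, 2, 3, 4, 5, 10, 20, 50, 100],
)


def observe_one_finished_request(retraction_count: int) -> None:
    # One observation per finished request: how many times it was retracted
    # (pulled out of the running batch and re-queued) before completing.
    num_retractions.labels(**labels).observe(retraction_count)


observe_one_finished_request(0)
observe_one_finished_request(3)
print(generate_latest().decode()[:300])  # inspect the exported samples
```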
@@ -840,13 +870,13 @@ class TokenizerMetricsCollector:
     def check_time_to_first_token_straggler(self, value: float) -> bool:
         his = self.histogram_time_to_first_token.labels(**self.labels)
         total_observations = sum(bucket._value for bucket in his._buckets)
-        if total_observations <
+        if total_observations < 100:
             return False
-
+        p99_threshold = total_observations * 0.99
         cumulative_count = 0
         for i, bucket in enumerate(his._buckets):
             cumulative_count += bucket._value
-            if cumulative_count >
+            if cumulative_count > p99_threshold:
                 return value >= his._upper_bounds[i]
         return False
 
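The straggler check now has explicit constants: it waits for at least 100 observations, then walks the cumulative bucket counts until it passes 99% of the total and treats anything at or above that bucket's upper bound as a straggler. The same walk on plain lists, with made-up bucket data, shows where the threshold lands:

```python
def is_p99_straggler(value, upper_bounds, bucket_counts):
    """Flag `value` if it falls in or beyond the bucket holding the 99th percentile."""
    total_observations = sum(bucket_counts)
    if total_observations < 100:        # too few samples to judge
        return False
    p99_threshold = total_observations * 0.99
    cumulative_count = 0
    for upper_bound, count in zip(upper_bounds, bucket_counts):
        cumulative_count += count
        if cumulative_count > p99_threshold:
            return value >= upper_bound
    return False


# 200 TTFT samples: 150 below 0.5s, 48 below 1.0s, 2 below 5.0s.
bounds, counts = [0.5, 1.0, 5.0], [150, 48, 2]
assert not is_p99_straggler(0.8, bounds, counts)   # well inside the p99 mass
assert is_p99_straggler(6.0, bounds, counts)       # beyond the 5.0s p99 bucket
```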
@@ -969,3 +999,16 @@ class StorageMetricsCollector:
             self._log_histogram(self.histogram_prefetch_bandwidth, v)
         for v in storage_metrics.backup_bandwidth:
             self._log_histogram(self.histogram_backup_bandwidth, v)
+
+
+class ExpertDispatchCollector:
+    def __init__(self, ep_size: int) -> None:
+        from prometheus_client import Histogram
+
+        ep_size_buckets = [i for i in range(ep_size)]
+        self.eplb_gpu_physical_count = Histogram(
+            name="sglang:eplb_gpu_physical_count",
+            documentation="The selected count of physical experts on each layer and GPU rank.",
+            labelnames={"layer"},
+            buckets=ep_size_buckets,
+        )
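A short usage sketch for the new collector: each observation records, for one layer, how many physical experts were selected on a GPU rank, so the per-layer bucket counts show how evenly expert dispatch spreads across the EP group. The values below are hypothetical and the import assumes an installed sglang:

```python
from sglang.srt.metrics.collector import ExpertDispatchCollector

collector = ExpertDispatchCollector(ep_size=8)

# Hypothetical dispatch results for two layers on this rank.
for layer_id, selected_physical_experts in [(0, 5), (1, 2)]:
    collector.eplb_gpu_physical_count.labels(layer=str(layer_id)).observe(
        selected_physical_experts
    )
```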
@@ -21,12 +21,14 @@ import inspect
 import logging
 import os
 from contextlib import contextmanager
+from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 import tqdm
 from torch.profiler import ProfilerActivity, profile
 
+from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
@@ -64,6 +66,7 @@ from sglang.srt.utils import (
     require_mlp_tp_gather,
 )
 from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 try:
     from kt_kernel import AMXMoEWrapper
@@ -320,11 +323,11 @@ class CudaGraphRunner:
             self.pp_proxy_tensors = {
                 "hidden_states": torch.zeros(
                     (self.max_bs, self.model_runner.model_config.hidden_size),
-                    dtype=
+                    dtype=self.model_runner.model_config.dtype,
                 ),
                 "residual": torch.zeros(
                     (self.max_bs, self.model_runner.model_config.hidden_size),
-                    dtype=
+                    dtype=self.model_runner.model_config.dtype,
                 ),
             }
 
@@ -518,7 +521,16 @@ class CudaGraphRunner:
             logger.info(log_message)
 
     def _capture_graph(self, graph, pool, stream, run_once_fn):
-
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.model_runner.server_args.enable_memory_saver
+            and get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH")
+        )
+        graph_fn = (
+            partial(memory_saver_adapter.cuda_graph, tag=GPU_MEMORY_TYPE_CUDA_GRAPH)
+            if memory_saver_adapter.enabled
+            else self.device_module.graph
+        )
+        with graph_fn(cuda_graph=graph, pool=pool, stream=stream):
             out = run_once_fn()
         return out
 
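`_capture_graph` now chooses its capture context at runtime: with the memory saver enabled (and `SGLANG_MEMORY_SAVER_CUDA_GRAPH` set), the graph is captured inside the adapter's tagged region; otherwise it falls back to the device module's plain graph context. The selection trick is `functools.partial` pre-binding the tag so both branches expose the same call signature; a toy sketch with stand-in context managers (not the real `TorchMemorySaverAdapter`):

```python
from contextlib import contextmanager
from functools import partial


@contextmanager
def plain_graph_capture(cuda_graph, pool, stream):
    # Stand-in for torch.cuda.graph(cuda_graph, pool=pool, stream=stream).
    yield


@contextmanager
def tagged_graph_capture(cuda_graph, pool, stream, tag):
    # Stand-in for memory_saver_adapter.cuda_graph: label the region, then capture.
    print(f"capturing under memory tag {tag!r}")
    with plain_graph_capture(cuda_graph, pool, stream):
        yield


def capture_graph(graph, pool, stream, run_once_fn, memory_saver_enabled):
    graph_fn = (
        partial(tagged_graph_capture, tag="cuda_graph")
        if memory_saver_enabled
        else plain_graph_capture
    )
    # Both branches accept the same keyword arguments thanks to partial().
    with graph_fn(cuda_graph=graph, pool=pool, stream=stream):
        return run_once_fn()


print(capture_graph(None, None, None, lambda: "captured", memory_saver_enabled=True))
```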
@@ -90,12 +90,9 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or (
-                self == ForwardMode.DRAFT_EXTEND_V2
-                if include_draft_extend_v2
-                else False
-            )
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
             or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.SPLIT_PREFILL
         )
 
     def is_decode(self):
@@ -114,22 +111,21 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.TARGET_VERIFY
 
     def is_draft_extend(self, include_v2: bool = False):
-
-
-
-            )
-        return self == ForwardMode.DRAFT_EXTEND
+        return self == ForwardMode.DRAFT_EXTEND or (
+            include_v2 and self == ForwardMode.DRAFT_EXTEND_V2
+        )
 
     def is_draft_extend_v2(self):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2
 
-    def is_extend_or_draft_extend_or_mixed(self):
+    def is_extend_or_draft_extend_or_mixed(self, include_draft_extend_v2: bool = False):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
         )
 
     def is_cuda_graph(self):
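Both `ForwardMode` rewrites replace the conditional-expression form with a flat boolean, which is equivalent whenever the inner expression is itself a boolean; a short check makes the equivalence explicit:

```python
for include_v2 in (True, False):
    for is_v2 in (True, False):
        assert (is_v2 if include_v2 else False) == (include_v2 and is_v2)
```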
@@ -319,6 +315,9 @@ class ForwardBatch:
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
     tbo_children: Optional[List[ForwardBatch]] = None
 
+    # For matryoshka embeddings
+    dimensions: Optional[list[int]] = None
+
     @classmethod
     def init_new(
         cls,
@@ -360,6 +359,7 @@ class ForwardBatch:
             input_embeds=batch.input_embeds,
             token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
+            dimensions=batch.dimensions,
         )
         device = model_runner.device
 
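The new `dimensions` field plumbs the per-request matryoshka embedding sizes from the scheduled batch into `ForwardBatch`, so the embedding path can cut each pooled vector down to the width the request asked for. A minimal sketch of that post-processing step, independent of the sglang implementation:

```python
import torch
import torch.nn.functional as F


def apply_matryoshka_dimensions(embeddings: torch.Tensor, dimensions: list) -> list:
    """Slice each pooled embedding to its requested width and re-normalize."""
    out = []
    for row, dim in zip(embeddings, dimensions):
        out.append(F.normalize(row[:dim], p=2, dim=-1))
    return out


pooled = torch.randn(2, 1024)                       # two full-width embeddings
reduced = apply_matryoshka_dimensions(pooled, [256, 64])
print([tuple(t.shape) for t in reduced])            # [(256,), (64,)]
```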