sglang 0.5.3.post2__py3-none-any.whl → 0.5.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +13 -8
 - sglang/srt/disaggregation/base/conn.py +17 -4
 - sglang/srt/disaggregation/common/conn.py +1 -0
 - sglang/srt/disaggregation/decode.py +113 -8
 - sglang/srt/disaggregation/fake/conn.py +11 -3
 - sglang/srt/disaggregation/mooncake/conn.py +148 -17
 - sglang/srt/disaggregation/nixl/conn.py +7 -1
 - sglang/srt/disaggregation/prefill.py +71 -1
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -3
 - sglang/srt/environ.py +3 -3
 - sglang/srt/layers/attention/ascend_backend.py +17 -0
 - sglang/srt/layers/layernorm.py +41 -9
 - sglang/srt/layers/logits_processor.py +1 -1
 - sglang/srt/layers/moe/utils.py +4 -2
 - sglang/srt/layers/rotary_embedding.py +16 -2
 - sglang/srt/layers/sampler.py +3 -3
 - sglang/srt/managers/scheduler.py +0 -6
 - sglang/srt/mem_cache/allocator_ascend.py +1 -1
 - sglang/srt/mem_cache/common.py +1 -5
 - sglang/srt/mem_cache/memory_pool.py +248 -137
 - sglang/srt/model_executor/model_runner.py +28 -13
 - sglang/srt/model_executor/npu_graph_runner.py +2 -2
 - sglang/srt/model_loader/weight_utils.py +2 -2
 - sglang/srt/models/deepseek_v2.py +1 -0
 - sglang/srt/models/glm4_moe.py +4 -2
 - sglang/srt/server_args.py +31 -9
 - sglang/srt/speculative/eagle_worker.py +2 -2
 - sglang/srt/speculative/spec_info.py +2 -0
 - sglang/srt/speculative/standalone_worker.py +1 -1
 - sglang/test/runners.py +1 -1
 - sglang/test/send_one.py +27 -1
 - sglang/test/test_disaggregation_utils.py +33 -15
 - sglang/test/test_utils.py +37 -2
 - sglang/version.py +1 -1
 - {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/METADATA +1 -1
 - {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/RECORD +39 -39
 - {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/WHEEL +0 -0
 - {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/top_level.txt +0 -0
 
    
        sglang/bench_one_batch.py
    CHANGED
    
    | 
         @@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm 
     | 
|
| 
       72 
72 
     | 
    
         
             
            from sglang.srt.utils import (
         
     | 
| 
       73 
73 
     | 
    
         
             
                configure_logger,
         
     | 
| 
       74 
74 
     | 
    
         
             
                get_bool_env_var,
         
     | 
| 
      
 75 
     | 
    
         
            +
                is_cuda_alike,
         
     | 
| 
      
 76 
     | 
    
         
            +
                is_xpu,
         
     | 
| 
       75 
77 
     | 
    
         
             
                kill_process_tree,
         
     | 
| 
       76 
78 
     | 
    
         
             
                require_mlp_sync,
         
     | 
| 
       77 
79 
     | 
    
         
             
                require_mlp_tp_gather,
         
     | 
| 
         @@ -80,6 +82,15 @@ from sglang.srt.utils import ( 
     | 
|
| 
       80 
82 
     | 
    
         
             
            )
         
     | 
| 
       81 
83 
     | 
    
         
             
            from sglang.srt.utils.hf_transformers_utils import get_tokenizer
         
     | 
| 
       82 
84 
     | 
    
         | 
| 
      
 85 
     | 
    
         
            +
            profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
         
     | 
| 
      
 86 
     | 
    
         
            +
                profiler_activity
         
     | 
| 
      
 87 
     | 
    
         
            +
                for available, profiler_activity in [
         
     | 
| 
      
 88 
     | 
    
         
            +
                    (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
         
     | 
| 
      
 89 
     | 
    
         
            +
                    (is_xpu(), torch.profiler.ProfilerActivity.XPU),
         
     | 
| 
      
 90 
     | 
    
         
            +
                ]
         
     | 
| 
      
 91 
     | 
    
         
            +
                if available
         
     | 
| 
      
 92 
     | 
    
         
            +
            ]
         
     | 
| 
      
 93 
     | 
    
         
            +
             
     | 
| 
       83 
94 
     | 
    
         | 
| 
       84 
95 
     | 
    
         
             
            @dataclasses.dataclass
         
     | 
| 
       85 
96 
     | 
    
         
             
            class BenchArgs:
         
     | 
| 
         @@ -424,10 +435,7 @@ def latency_test_run_once( 
     | 
|
| 
       424 
435 
     | 
    
         
             
                profiler = None
         
     | 
| 
       425 
436 
     | 
    
         
             
                if profile:
         
     | 
| 
       426 
437 
     | 
    
         
             
                    profiler = torch.profiler.profile(
         
     | 
| 
       427 
     | 
    
         
            -
                        activities= 
     | 
| 
       428 
     | 
    
         
            -
                            torch.profiler.ProfilerActivity.CPU,
         
     | 
| 
       429 
     | 
    
         
            -
                            torch.profiler.ProfilerActivity.CUDA,
         
     | 
| 
       430 
     | 
    
         
            -
                        ],
         
     | 
| 
      
 438 
     | 
    
         
            +
                        activities=profile_activities,
         
     | 
| 
       431 
439 
     | 
    
         
             
                        with_stack=True,
         
     | 
| 
       432 
440 
     | 
    
         
             
                        record_shapes=profile_record_shapes,
         
     | 
| 
       433 
441 
     | 
    
         
             
                    )
         
     | 
| 
         @@ -460,10 +468,7 @@ def latency_test_run_once( 
     | 
|
| 
       460 
468 
     | 
    
         
             
                    if profile and i == output_len / 2:
         
     | 
| 
       461 
469 
     | 
    
         
             
                        profiler = None
         
     | 
| 
       462 
470 
     | 
    
         
             
                        profiler = torch.profiler.profile(
         
     | 
| 
       463 
     | 
    
         
            -
                            activities= 
     | 
| 
       464 
     | 
    
         
            -
                                torch.profiler.ProfilerActivity.CPU,
         
     | 
| 
       465 
     | 
    
         
            -
                                torch.profiler.ProfilerActivity.CUDA,
         
     | 
| 
       466 
     | 
    
         
            -
                            ],
         
     | 
| 
      
 471 
     | 
    
         
            +
                            activities=profile_activities,
         
     | 
| 
       467 
472 
     | 
    
         
             
                            with_stack=True,
         
     | 
| 
       468 
473 
     | 
    
         
             
                            record_shapes=profile_record_shapes,
         
     | 
| 
       469 
474 
     | 
    
         
             
                        )
         
     | 
| 
         @@ -20,6 +20,10 @@ class KVArgs: 
     | 
|
| 
       20 
20 
     | 
    
         
             
                aux_data_ptrs: List[int]
         
     | 
| 
       21 
21 
     | 
    
         
             
                aux_data_lens: List[int]
         
     | 
| 
       22 
22 
     | 
    
         
             
                aux_item_lens: List[int]
         
     | 
| 
      
 23 
     | 
    
         
            +
                state_data_ptrs: List[int]
         
     | 
| 
      
 24 
     | 
    
         
            +
                state_data_lens: List[int]
         
     | 
| 
      
 25 
     | 
    
         
            +
                state_item_lens: List[int]
         
     | 
| 
      
 26 
     | 
    
         
            +
                state_type: str  # "none", "mamba", "swa"
         
     | 
| 
       23 
27 
     | 
    
         
             
                ib_device: str
         
     | 
| 
       24 
28 
     | 
    
         
             
                ib_traffic_class: str
         
     | 
| 
       25 
29 
     | 
    
         
             
                gpu_id: int
         
     | 
| 
         @@ -76,9 +80,13 @@ class BaseKVSender(ABC): 
     | 
|
| 
       76 
80 
     | 
    
         
             
                    ...
         
     | 
| 
       77 
81 
     | 
    
         | 
| 
       78 
82 
     | 
    
         
             
                @abstractmethod
         
     | 
| 
       79 
     | 
    
         
            -
                def send( 
     | 
| 
      
 83 
     | 
    
         
            +
                def send(
         
     | 
| 
      
 84 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 85 
     | 
    
         
            +
                    kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 86 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
      
 87 
     | 
    
         
            +
                ):
         
     | 
| 
       80 
88 
     | 
    
         
             
                    """
         
     | 
| 
       81 
     | 
    
         
            -
                    Send the kv cache at the given kv indices to the decoder server
         
     | 
| 
      
 89 
     | 
    
         
            +
                    Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
         
     | 
| 
       82 
90 
     | 
    
         
             
                    """
         
     | 
| 
       83 
91 
     | 
    
         
             
                    ...
         
     | 
| 
       84 
92 
     | 
    
         | 
| 
         @@ -108,9 +116,14 @@ class BaseKVReceiver(ABC): 
     | 
|
| 
       108 
116 
     | 
    
         
             
                ): ...
         
     | 
| 
       109 
117 
     | 
    
         | 
| 
       110 
118 
     | 
    
         
             
                @abstractmethod
         
     | 
| 
       111 
     | 
    
         
            -
                def init( 
     | 
| 
      
 119 
     | 
    
         
            +
                def init(
         
     | 
| 
      
 120 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 121 
     | 
    
         
            +
                    kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 122 
     | 
    
         
            +
                    aux_index: Optional[int] = None,
         
     | 
| 
      
 123 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
      
 124 
     | 
    
         
            +
                ):
         
     | 
| 
       112 
125 
     | 
    
         
             
                    """
         
     | 
| 
       113 
     | 
    
         
            -
                    Notify the prefill server about the kv indices  
     | 
| 
      
 126 
     | 
    
         
            +
                    Notify the prefill server about the kv indices, aux index, and state_indices.
         
     | 
| 
       114 
127 
     | 
    
         
             
                    """
         
     | 
| 
       115 
128 
     | 
    
         
             
                    ...
         
     | 
| 
       116 
129 
     | 
    
         | 
| 
         @@ -25,11 +25,12 @@ import time 
     | 
|
| 
       25 
25 
     | 
    
         
             
            from collections import deque
         
     | 
| 
       26 
26 
     | 
    
         
             
            from dataclasses import dataclass
         
     | 
| 
       27 
27 
     | 
    
         
             
            from http import HTTPStatus
         
     | 
| 
       28 
     | 
    
         
            -
            from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
         
     | 
| 
      
 28 
     | 
    
         
            +
            from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
         
     | 
| 
       29 
29 
     | 
    
         | 
| 
       30 
30 
     | 
    
         
             
            import torch
         
     | 
| 
       31 
31 
     | 
    
         
             
            from torch.distributed import ProcessGroup
         
     | 
| 
       32 
32 
     | 
    
         | 
| 
      
 33 
     | 
    
         
            +
            from sglang.srt.configs.mamba_utils import Mamba2CacheParams
         
     | 
| 
       33 
34 
     | 
    
         
             
            from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
         
     | 
| 
       34 
35 
     | 
    
         
             
            from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
         
     | 
| 
       35 
36 
     | 
    
         
             
            from sglang.srt.disaggregation.utils import (
         
     | 
| 
         @@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import ( 
     | 
|
| 
       47 
48 
     | 
    
         
             
            )
         
     | 
| 
       48 
49 
     | 
    
         
             
            from sglang.srt.layers.dp_attention import get_attention_tp_size
         
     | 
| 
       49 
50 
     | 
    
         
             
            from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
         
     | 
| 
       50 
     | 
    
         
            -
            from sglang.srt.mem_cache.allocator import  
     | 
| 
      
 51 
     | 
    
         
            +
            from sglang.srt.mem_cache.allocator import (
         
     | 
| 
      
 52 
     | 
    
         
            +
                BaseTokenToKVPoolAllocator,
         
     | 
| 
      
 53 
     | 
    
         
            +
                SWATokenToKVPoolAllocator,
         
     | 
| 
      
 54 
     | 
    
         
            +
            )
         
     | 
| 
       51 
55 
     | 
    
         
             
            from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
         
     | 
| 
       52 
     | 
    
         
            -
            from sglang.srt.mem_cache.memory_pool import  
     | 
| 
      
 56 
     | 
    
         
            +
            from sglang.srt.mem_cache.memory_pool import (
         
     | 
| 
      
 57 
     | 
    
         
            +
                HybridLinearKVPool,
         
     | 
| 
      
 58 
     | 
    
         
            +
                HybridReqToTokenPool,
         
     | 
| 
      
 59 
     | 
    
         
            +
                KVCache,
         
     | 
| 
      
 60 
     | 
    
         
            +
                NSATokenToKVPool,
         
     | 
| 
      
 61 
     | 
    
         
            +
                ReqToTokenPool,
         
     | 
| 
      
 62 
     | 
    
         
            +
                SWAKVPool,
         
     | 
| 
      
 63 
     | 
    
         
            +
            )
         
     | 
| 
       53 
64 
     | 
    
         
             
            from sglang.srt.model_executor.forward_batch_info import ForwardMode
         
     | 
| 
       54 
65 
     | 
    
         
             
            from sglang.srt.utils import get_int_env_var, require_mlp_sync
         
     | 
| 
       55 
66 
     | 
    
         
             
            from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
         
     | 
| 
         @@ -124,6 +135,35 @@ class DecodeReqToTokenPool: 
     | 
|
| 
       124 
135 
     | 
    
         
             
                    self.free_slots = list(range(self.size + self.pre_alloc_size))
         
     | 
| 
       125 
136 
     | 
    
         | 
| 
       126 
137 
     | 
    
         | 
| 
      
 138 
     | 
    
         
            +
            class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
         
     | 
| 
      
 139 
     | 
    
         
            +
             
     | 
| 
      
 140 
     | 
    
         
            +
                def __init__(
         
     | 
| 
      
 141 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 142 
     | 
    
         
            +
                    size: int,
         
     | 
| 
      
 143 
     | 
    
         
            +
                    max_context_len: int,
         
     | 
| 
      
 144 
     | 
    
         
            +
                    device: str,
         
     | 
| 
      
 145 
     | 
    
         
            +
                    enable_memory_saver: bool,
         
     | 
| 
      
 146 
     | 
    
         
            +
                    cache_params: "Mamba2CacheParams",
         
     | 
| 
      
 147 
     | 
    
         
            +
                    speculative_num_draft_tokens: int,
         
     | 
| 
      
 148 
     | 
    
         
            +
                    pre_alloc_size: int,
         
     | 
| 
      
 149 
     | 
    
         
            +
                ):
         
     | 
| 
      
 150 
     | 
    
         
            +
                    DecodeReqToTokenPool.__init__(
         
     | 
| 
      
 151 
     | 
    
         
            +
                        self,
         
     | 
| 
      
 152 
     | 
    
         
            +
                        size=size,
         
     | 
| 
      
 153 
     | 
    
         
            +
                        max_context_len=max_context_len,
         
     | 
| 
      
 154 
     | 
    
         
            +
                        device=device,
         
     | 
| 
      
 155 
     | 
    
         
            +
                        enable_memory_saver=enable_memory_saver,
         
     | 
| 
      
 156 
     | 
    
         
            +
                        pre_alloc_size=pre_alloc_size,
         
     | 
| 
      
 157 
     | 
    
         
            +
                    )
         
     | 
| 
      
 158 
     | 
    
         
            +
                    self._init_mamba_pool(
         
     | 
| 
      
 159 
     | 
    
         
            +
                        size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
         
     | 
| 
      
 160 
     | 
    
         
            +
                    )
         
     | 
| 
      
 161 
     | 
    
         
            +
             
     | 
| 
      
 162 
     | 
    
         
            +
                def clear(self):
         
     | 
| 
      
 163 
     | 
    
         
            +
                    self.free_slots = list(range(self.size + self.pre_alloc_size))
         
     | 
| 
      
 164 
     | 
    
         
            +
                    self.mamba_pool.clear()
         
     | 
| 
      
 165 
     | 
    
         
            +
             
     | 
| 
      
 166 
     | 
    
         
            +
             
     | 
| 
       127 
167 
     | 
    
         
             
            @dataclass
         
     | 
| 
       128 
168 
     | 
    
         
             
            class DecodeRequest:
         
     | 
| 
       129 
169 
     | 
    
         
             
                req: Req
         
     | 
| 
         @@ -217,6 +257,28 @@ class DecodePreallocQueue: 
     | 
|
| 
       217 
257 
     | 
    
         
             
                        self.metadata_buffers.get_buf_infos()
         
     | 
| 
       218 
258 
     | 
    
         
             
                    )
         
     | 
| 
       219 
259 
     | 
    
         | 
| 
      
 260 
     | 
    
         
            +
                    if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
         
     | 
| 
      
 261 
     | 
    
         
            +
                        state_data_ptrs, state_data_lens, state_item_lens = (
         
     | 
| 
      
 262 
     | 
    
         
            +
                            self.token_to_kv_pool.get_state_buf_infos()
         
     | 
| 
      
 263 
     | 
    
         
            +
                        )
         
     | 
| 
      
 264 
     | 
    
         
            +
                        kv_args.state_data_ptrs = state_data_ptrs
         
     | 
| 
      
 265 
     | 
    
         
            +
                        kv_args.state_data_lens = state_data_lens
         
     | 
| 
      
 266 
     | 
    
         
            +
                        kv_args.state_item_lens = state_item_lens
         
     | 
| 
      
 267 
     | 
    
         
            +
             
     | 
| 
      
 268 
     | 
    
         
            +
                        if isinstance(self.token_to_kv_pool, SWAKVPool):
         
     | 
| 
      
 269 
     | 
    
         
            +
                            kv_args.state_type = "swa"
         
     | 
| 
      
 270 
     | 
    
         
            +
                        elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
         
     | 
| 
      
 271 
     | 
    
         
            +
                            kv_args.state_type = "mamba"
         
     | 
| 
      
 272 
     | 
    
         
            +
                        elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
         
     | 
| 
      
 273 
     | 
    
         
            +
                            kv_args.state_type = "nsa"
         
     | 
| 
      
 274 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 275 
     | 
    
         
            +
                            kv_args.state_type = "none"
         
     | 
| 
      
 276 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 277 
     | 
    
         
            +
                        kv_args.state_data_ptrs = []
         
     | 
| 
      
 278 
     | 
    
         
            +
                        kv_args.state_data_lens = []
         
     | 
| 
      
 279 
     | 
    
         
            +
                        kv_args.state_item_lens = []
         
     | 
| 
      
 280 
     | 
    
         
            +
                        kv_args.state_type = "none"
         
     | 
| 
      
 281 
     | 
    
         
            +
             
     | 
| 
       220 
282 
     | 
    
         
             
                    kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         
     | 
| 
       221 
283 
     | 
    
         
             
                    kv_args.gpu_id = self.scheduler.gpu_id
         
     | 
| 
       222 
284 
     | 
    
         
             
                    kv_manager_class: Type[BaseKVManager] = get_kv_class(
         
     | 
| 
         @@ -414,16 +476,56 @@ class DecodePreallocQueue: 
     | 
|
| 
       414 
476 
     | 
    
         
             
                            .cpu()
         
     | 
| 
       415 
477 
     | 
    
         
             
                            .numpy()
         
     | 
| 
       416 
478 
     | 
    
         
             
                        )
         
     | 
| 
      
 479 
     | 
    
         
            +
                        page_size = self.token_to_kv_pool_allocator.page_size
         
     | 
| 
      
 480 
     | 
    
         
            +
             
     | 
| 
      
 481 
     | 
    
         
            +
                        # Prepare extra pool indices for hybrid models
         
     | 
| 
      
 482 
     | 
    
         
            +
                        if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
         
     | 
| 
      
 483 
     | 
    
         
            +
                            # Mamba hybrid model: single mamba state index
         
     | 
| 
      
 484 
     | 
    
         
            +
                            state_indices = [
         
     | 
| 
      
 485 
     | 
    
         
            +
                                self.req_to_token_pool.req_index_to_mamba_index_mapping[
         
     | 
| 
      
 486 
     | 
    
         
            +
                                    decode_req.req.req_pool_idx
         
     | 
| 
      
 487 
     | 
    
         
            +
                                ]
         
     | 
| 
      
 488 
     | 
    
         
            +
                                .cpu()
         
     | 
| 
      
 489 
     | 
    
         
            +
                                .numpy()
         
     | 
| 
      
 490 
     | 
    
         
            +
                            ]
         
     | 
| 
      
 491 
     | 
    
         
            +
                        elif isinstance(self.token_to_kv_pool, SWAKVPool):
         
     | 
| 
      
 492 
     | 
    
         
            +
                            # SWA hybrid model: send decode-side SWA window indices
         
     | 
| 
      
 493 
     | 
    
         
            +
                            seq_len = len(decode_req.req.origin_input_ids)
         
     | 
| 
      
 494 
     | 
    
         
            +
                            window_size = self.scheduler.sliding_window_size
         
     | 
| 
      
 495 
     | 
    
         
            +
             
     | 
| 
      
 496 
     | 
    
         
            +
                            window_start = max(0, seq_len - window_size)
         
     | 
| 
      
 497 
     | 
    
         
            +
                            window_start = (window_start // page_size) * page_size
         
     | 
| 
      
 498 
     | 
    
         
            +
                            window_kv_indices_full = self.req_to_token_pool.req_to_token[
         
     | 
| 
      
 499 
     | 
    
         
            +
                                decode_req.req.req_pool_idx, window_start:seq_len
         
     | 
| 
      
 500 
     | 
    
         
            +
                            ]
         
     | 
| 
      
 501 
     | 
    
         
            +
             
     | 
| 
      
 502 
     | 
    
         
            +
                            # Translate to SWA pool indices
         
     | 
| 
      
 503 
     | 
    
         
            +
                            window_kv_indices_swa = (
         
     | 
| 
      
 504 
     | 
    
         
            +
                                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
         
     | 
| 
      
 505 
     | 
    
         
            +
                                    window_kv_indices_full
         
     | 
| 
      
 506 
     | 
    
         
            +
                                )
         
     | 
| 
      
 507 
     | 
    
         
            +
                            )
         
     | 
| 
      
 508 
     | 
    
         
            +
                            state_indices = window_kv_indices_swa.cpu().numpy()
         
     | 
| 
      
 509 
     | 
    
         
            +
                            state_indices = kv_to_page_indices(state_indices, page_size)
         
     | 
| 
      
 510 
     | 
    
         
            +
                        elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
         
     | 
| 
      
 511 
     | 
    
         
            +
                            seq_len = len(decode_req.req.origin_input_ids)
         
     | 
| 
      
 512 
     | 
    
         
            +
                            kv_indices_full = self.req_to_token_pool.req_to_token[
         
     | 
| 
      
 513 
     | 
    
         
            +
                                decode_req.req.req_pool_idx, :seq_len
         
     | 
| 
      
 514 
     | 
    
         
            +
                            ]
         
     | 
| 
      
 515 
     | 
    
         
            +
                            state_indices = kv_indices_full.cpu().numpy()
         
     | 
| 
      
 516 
     | 
    
         
            +
                            state_indices = kv_to_page_indices(state_indices, page_size)
         
     | 
| 
      
 517 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 518 
     | 
    
         
            +
                            state_indices = None
         
     | 
| 
       417 
519 
     | 
    
         | 
| 
       418 
520 
     | 
    
         
             
                        decode_req.metadata_buffer_index = (
         
     | 
| 
       419 
521 
     | 
    
         
             
                            self.req_to_metadata_buffer_idx_allocator.alloc()
         
     | 
| 
       420 
522 
     | 
    
         
             
                        )
         
     | 
| 
       421 
523 
     | 
    
         
             
                        assert decode_req.metadata_buffer_index is not None
         
     | 
| 
       422 
     | 
    
         
            -
                        page_indices = kv_to_page_indices(
         
     | 
| 
       423 
     | 
    
         
            -
             
     | 
| 
      
 524 
     | 
    
         
            +
                        page_indices = kv_to_page_indices(kv_indices, page_size)
         
     | 
| 
      
 525 
     | 
    
         
            +
                        decode_req.kv_receiver.init(
         
     | 
| 
      
 526 
     | 
    
         
            +
                            page_indices, decode_req.metadata_buffer_index, state_indices
         
     | 
| 
       424 
527 
     | 
    
         
             
                        )
         
     | 
| 
       425 
     | 
    
         
            -
                        decode_req. 
     | 
| 
       426 
     | 
    
         
            -
             
     | 
| 
      
 528 
     | 
    
         
            +
                        decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
         
     | 
| 
       427 
529 
     | 
    
         
             
                        preallocated_reqs.append(decode_req)
         
     | 
| 
       428 
530 
     | 
    
         
             
                        indices_to_remove.add(i)
         
     | 
| 
       429 
531 
     | 
    
         
             
                        decode_req.req.time_stats.decode_transfer_queue_entry_time = (
         
     | 
| 
         @@ -503,7 +605,10 @@ class DecodePreallocQueue: 
     | 
|
| 
       503 
605 
     | 
    
         | 
| 
       504 
606 
     | 
    
         
             
                def _pre_alloc(self, req: Req) -> torch.Tensor:
         
     | 
| 
       505 
607 
     | 
    
         
             
                    """Pre-allocate the memory for req_to_token and token_kv_pool"""
         
     | 
| 
       506 
     | 
    
         
            -
                     
     | 
| 
      
 608 
     | 
    
         
            +
                    if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
         
     | 
| 
      
 609 
     | 
    
         
            +
                        req_pool_indices = self.req_to_token_pool.alloc(1, [req])
         
     | 
| 
      
 610 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 611 
     | 
    
         
            +
                        req_pool_indices = self.req_to_token_pool.alloc(1)
         
     | 
| 
       507 
612 
     | 
    
         | 
| 
       508 
613 
     | 
    
         
             
                    assert (
         
     | 
| 
       509 
614 
     | 
    
         
             
                        req_pool_indices is not None
         
     | 
| 
         @@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender): 
     | 
|
| 
       48 
48 
     | 
    
         
             
                def send(
         
     | 
| 
       49 
49 
     | 
    
         
             
                    self,
         
     | 
| 
       50 
50 
     | 
    
         
             
                    kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 51 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
       51 
52 
     | 
    
         
             
                ):
         
     | 
| 
       52 
53 
     | 
    
         
             
                    self.has_sent = True
         
     | 
| 
       53 
     | 
    
         
            -
                    logger.debug( 
     | 
| 
      
 54 
     | 
    
         
            +
                    logger.debug(
         
     | 
| 
      
 55 
     | 
    
         
            +
                        f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
         
     | 
| 
      
 56 
     | 
    
         
            +
                    )
         
     | 
| 
       54 
57 
     | 
    
         | 
| 
       55 
58 
     | 
    
         
             
                def failure_exception(self):
         
     | 
| 
       56 
59 
     | 
    
         
             
                    raise Exception("Fake KVSender Exception")
         
     | 
| 
         @@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver): 
     | 
|
| 
       75 
78 
     | 
    
         
             
                        logger.debug("FakeKVReceiver poll success")
         
     | 
| 
       76 
79 
     | 
    
         
             
                        return KVPoll.Success
         
     | 
| 
       77 
80 
     | 
    
         | 
| 
       78 
     | 
    
         
            -
                def init( 
     | 
| 
      
 81 
     | 
    
         
            +
                def init(
         
     | 
| 
      
 82 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 83 
     | 
    
         
            +
                    kv_indices: list[int],
         
     | 
| 
      
 84 
     | 
    
         
            +
                    aux_index: Optional[int] = None,
         
     | 
| 
      
 85 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
      
 86 
     | 
    
         
            +
                ):
         
     | 
| 
       79 
87 
     | 
    
         
             
                    self.has_init = True
         
     | 
| 
       80 
88 
     | 
    
         
             
                    logger.debug(
         
     | 
| 
       81 
     | 
    
         
            -
                        f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
         
     | 
| 
      
 89 
     | 
    
         
            +
                        f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
         
     | 
| 
       82 
90 
     | 
    
         
             
                    )
         
     | 
| 
       83 
91 
     | 
    
         | 
| 
       84 
92 
     | 
    
         
             
                def failure_exception(self):
         
     | 
| 
         @@ -58,6 +58,7 @@ class TransferKVChunk: 
     | 
|
| 
       58 
58 
     | 
    
         
             
                index_slice: slice
         
     | 
| 
       59 
59 
     | 
    
         
             
                is_last: bool
         
     | 
| 
       60 
60 
     | 
    
         
             
                prefill_aux_index: Optional[int]
         
     | 
| 
      
 61 
     | 
    
         
            +
                state_indices: Optional[List[int]]
         
     | 
| 
       61 
62 
     | 
    
         | 
| 
       62 
63 
     | 
    
         | 
| 
       63 
64 
     | 
    
         
             
            # decode
         
     | 
| 
         @@ -69,6 +70,7 @@ class TransferInfo: 
     | 
|
| 
       69 
70 
     | 
    
         
             
                mooncake_session_id: str
         
     | 
| 
       70 
71 
     | 
    
         
             
                dst_kv_indices: npt.NDArray[np.int32]
         
     | 
| 
       71 
72 
     | 
    
         
             
                dst_aux_index: int
         
     | 
| 
      
 73 
     | 
    
         
            +
                dst_state_indices: List[int]
         
     | 
| 
       72 
74 
     | 
    
         
             
                required_dst_info_num: int
         
     | 
| 
       73 
75 
     | 
    
         
             
                is_dummy: bool
         
     | 
| 
       74 
76 
     | 
    
         | 
| 
         @@ -78,9 +80,14 @@ class TransferInfo: 
     | 
|
| 
       78 
80 
     | 
    
         
             
                        is_dummy = True
         
     | 
| 
       79 
81 
     | 
    
         
             
                        dst_kv_indices = np.array([], dtype=np.int32)
         
     | 
| 
       80 
82 
     | 
    
         
             
                        dst_aux_index = None
         
     | 
| 
      
 83 
     | 
    
         
            +
                        dst_state_indices = []
         
     | 
| 
       81 
84 
     | 
    
         
             
                    else:
         
     | 
| 
       82 
85 
     | 
    
         
             
                        dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
         
     | 
| 
       83 
86 
     | 
    
         
             
                        dst_aux_index = int(msg[5].decode("ascii"))
         
     | 
| 
      
 87 
     | 
    
         
            +
                        if msg[6] == b"":
         
     | 
| 
      
 88 
     | 
    
         
            +
                            dst_state_indices = []
         
     | 
| 
      
 89 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 90 
     | 
    
         
            +
                            dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
         
     | 
| 
       84 
91 
     | 
    
         
             
                        is_dummy = False
         
     | 
| 
       85 
92 
     | 
    
         
             
                    return cls(
         
     | 
| 
       86 
93 
     | 
    
         
             
                        room=int(msg[0].decode("ascii")),
         
     | 
| 
         @@ -89,7 +96,8 @@ class TransferInfo: 
     | 
|
| 
       89 
96 
     | 
    
         
             
                        mooncake_session_id=msg[3].decode("ascii"),
         
     | 
| 
       90 
97 
     | 
    
         
             
                        dst_kv_indices=dst_kv_indices,
         
     | 
| 
       91 
98 
     | 
    
         
             
                        dst_aux_index=dst_aux_index,
         
     | 
| 
       92 
     | 
    
         
            -
                         
     | 
| 
      
 99 
     | 
    
         
            +
                        dst_state_indices=dst_state_indices,
         
     | 
| 
      
 100 
     | 
    
         
            +
                        required_dst_info_num=int(msg[7].decode("ascii")),
         
     | 
| 
       93 
101 
     | 
    
         
             
                        is_dummy=is_dummy,
         
     | 
| 
       94 
102 
     | 
    
         
             
                    )
         
     | 
| 
       95 
103 
     | 
    
         | 
| 
         @@ -103,6 +111,7 @@ class KVArgsRegisterInfo: 
     | 
|
| 
       103 
111 
     | 
    
         
             
                mooncake_session_id: str
         
     | 
| 
       104 
112 
     | 
    
         
             
                dst_kv_ptrs: list[int]
         
     | 
| 
       105 
113 
     | 
    
         
             
                dst_aux_ptrs: list[int]
         
     | 
| 
      
 114 
     | 
    
         
            +
                dst_state_data_ptrs: list[int]
         
     | 
| 
       106 
115 
     | 
    
         
             
                dst_tp_rank: int
         
     | 
| 
       107 
116 
     | 
    
         
             
                dst_attn_tp_size: int
         
     | 
| 
       108 
117 
     | 
    
         
             
                dst_kv_item_len: int
         
     | 
| 
         @@ -116,9 +125,10 @@ class KVArgsRegisterInfo: 
     | 
|
| 
       116 
125 
     | 
    
         
             
                        mooncake_session_id=msg[3].decode("ascii"),
         
     | 
| 
       117 
126 
     | 
    
         
             
                        dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
         
     | 
| 
       118 
127 
     | 
    
         
             
                        dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
         
     | 
| 
       119 
     | 
    
         
            -
                         
     | 
| 
       120 
     | 
    
         
            -
                         
     | 
| 
       121 
     | 
    
         
            -
                         
     | 
| 
      
 128 
     | 
    
         
            +
                        dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
         
     | 
| 
      
 129 
     | 
    
         
            +
                        dst_tp_rank=int(msg[7].decode("ascii")),
         
     | 
| 
      
 130 
     | 
    
         
            +
                        dst_attn_tp_size=int(msg[8].decode("ascii")),
         
     | 
| 
      
 131 
     | 
    
         
            +
                        dst_kv_item_len=int(msg[9].decode("ascii")),
         
     | 
| 
       122 
132 
     | 
    
         
             
                    )
         
     | 
| 
       123 
133 
     | 
    
         | 
| 
       124 
134 
     | 
    
         | 
| 
         @@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       180 
190 
     | 
    
         
             
                            )
         
     | 
| 
       181 
191 
     | 
    
         
             
                            for _ in range(transfer_queue_size)
         
     | 
| 
       182 
192 
     | 
    
         
             
                        ]
         
     | 
| 
      
 193 
     | 
    
         
            +
                        self.state_executors = concurrent.futures.ThreadPoolExecutor(
         
     | 
| 
      
 194 
     | 
    
         
            +
                            transfer_thread_pool_size // transfer_queue_size
         
     | 
| 
      
 195 
     | 
    
         
            +
                        )
         
     | 
| 
       183 
196 
     | 
    
         
             
                        for queue, executor in zip(self.transfer_queues, self.executors):
         
     | 
| 
       184 
197 
     | 
    
         
             
                            threading.Thread(
         
     | 
| 
       185 
198 
     | 
    
         
             
                                target=self.transfer_worker, args=(queue, executor), daemon=True
         
     | 
| 
         @@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       239 
252 
     | 
    
         
             
                            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         
     | 
| 
       240 
253 
     | 
    
         
             
                        )
         
     | 
| 
       241 
254 
     | 
    
         | 
| 
      
 255 
     | 
    
         
            +
                    # Batch register state/extra pool data buffers
         
     | 
| 
      
 256 
     | 
    
         
            +
                    if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
         
     | 
| 
      
 257 
     | 
    
         
            +
                        self.engine.batch_register(
         
     | 
| 
      
 258 
     | 
    
         
            +
                            self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
         
     | 
| 
      
 259 
     | 
    
         
            +
                        )
         
     | 
| 
      
 260 
     | 
    
         
            +
             
     | 
| 
       242 
261 
     | 
    
         
             
                def _transfer_data(self, mooncake_session_id, transfer_blocks):
         
     | 
| 
       243 
262 
     | 
    
         
             
                    if not transfer_blocks:
         
     | 
| 
       244 
263 
     | 
    
         
             
                        return 0
         
     | 
| 
         @@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       248 
267 
     | 
    
         
             
                        mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
         
     | 
| 
       249 
268 
     | 
    
         
             
                    )
         
     | 
| 
       250 
269 
     | 
    
         | 
| 
       251 
     | 
    
         
            -
                def  
     | 
| 
      
 270 
     | 
    
         
            +
                def _send_kvcache_generic(
         
     | 
| 
       252 
271 
     | 
    
         
             
                    self,
         
     | 
| 
       253 
272 
     | 
    
         
             
                    mooncake_session_id: str,
         
     | 
| 
       254 
     | 
    
         
            -
                     
     | 
| 
       255 
     | 
    
         
            -
                     
     | 
| 
       256 
     | 
    
         
            -
                     
     | 
| 
      
 273 
     | 
    
         
            +
                    src_data_ptrs: list[int],
         
     | 
| 
      
 274 
     | 
    
         
            +
                    dst_data_ptrs: list[int],
         
     | 
| 
      
 275 
     | 
    
         
            +
                    item_lens: list[int],
         
     | 
| 
      
 276 
     | 
    
         
            +
                    prefill_data_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 277 
     | 
    
         
            +
                    dst_data_indices: npt.NDArray[np.int32],
         
     | 
| 
       257 
278 
     | 
    
         
             
                    executor: concurrent.futures.ThreadPoolExecutor,
         
     | 
| 
       258 
     | 
    
         
            -
                ):
         
     | 
| 
       259 
     | 
    
         
            -
                     
     | 
| 
      
 279 
     | 
    
         
            +
                ) -> int:
         
     | 
| 
      
 280 
     | 
    
         
            +
                    """
         
     | 
| 
      
 281 
     | 
    
         
            +
                    Generic KV cache transfer supporting both MHA and MLA architectures.
         
     | 
| 
      
 282 
     | 
    
         
            +
                    This method is used by both send_kvcache (full pool) and maybe_send_extra.
         
     | 
| 
      
 283 
     | 
    
         
            +
                    """
         
     | 
| 
      
 284 
     | 
    
         
            +
                    # Group by indices for optimization
         
     | 
| 
       260 
285 
     | 
    
         
             
                    prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
         
     | 
| 
       261 
     | 
    
         
            -
                         
     | 
| 
      
 286 
     | 
    
         
            +
                        prefill_data_indices, dst_data_indices
         
     | 
| 
       262 
287 
     | 
    
         
             
                    )
         
     | 
| 
       263 
288 
     | 
    
         | 
| 
       264 
289 
     | 
    
         
             
                    layers_params = None
         
     | 
| 
         @@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       266 
291 
     | 
    
         
             
                    # pp is not supported on the decode side yet
         
     | 
| 
       267 
292 
     | 
    
         
             
                    if self.is_mla_backend:
         
     | 
| 
       268 
293 
     | 
    
         
             
                        src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
         
     | 
| 
       269 
     | 
    
         
            -
                            self.get_mla_kv_ptrs_with_pp( 
     | 
| 
      
 294 
     | 
    
         
            +
                            self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
         
     | 
| 
       270 
295 
     | 
    
         
             
                        )
         
     | 
| 
       271 
     | 
    
         
            -
                        kv_item_len =  
     | 
| 
      
 296 
     | 
    
         
            +
                        kv_item_len = item_lens[0]
         
     | 
| 
       272 
297 
     | 
    
         
             
                        layers_params = [
         
     | 
| 
       273 
298 
     | 
    
         
             
                            (
         
     | 
| 
       274 
299 
     | 
    
         
             
                                src_kv_ptrs[layer_id],
         
     | 
| 
         @@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       279 
304 
     | 
    
         
             
                        ]
         
     | 
| 
       280 
305 
     | 
    
         
             
                    else:
         
     | 
| 
       281 
306 
     | 
    
         
             
                        src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
         
     | 
| 
       282 
     | 
    
         
            -
                            self.get_mha_kv_ptrs_with_pp( 
     | 
| 
      
 307 
     | 
    
         
            +
                            self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
         
     | 
| 
       283 
308 
     | 
    
         
             
                        )
         
     | 
| 
       284 
     | 
    
         
            -
                        kv_item_len =  
     | 
| 
      
 309 
     | 
    
         
            +
                        kv_item_len = item_lens[0]
         
     | 
| 
       285 
310 
     | 
    
         
             
                        layers_params = [
         
     | 
| 
       286 
311 
     | 
    
         
             
                            (
         
     | 
| 
       287 
312 
     | 
    
         
             
                                src_k_ptrs[layer_id],
         
     | 
| 
         @@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       345 
370 
     | 
    
         | 
| 
       346 
371 
     | 
    
         
             
                    return 0
         
     | 
| 
       347 
372 
     | 
    
         | 
| 
      
 373 
     | 
    
         
            +
                def send_kvcache(
         
     | 
| 
      
 374 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 375 
     | 
    
         
            +
                    mooncake_session_id: str,
         
     | 
| 
      
 376 
     | 
    
         
            +
                    prefill_kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 377 
     | 
    
         
            +
                    dst_kv_ptrs: list[int],
         
     | 
| 
      
 378 
     | 
    
         
            +
                    dst_kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 379 
     | 
    
         
            +
                    executor: concurrent.futures.ThreadPoolExecutor,
         
     | 
| 
      
 380 
     | 
    
         
            +
                ):
         
     | 
| 
      
 381 
     | 
    
         
            +
                    return self._send_kvcache_generic(
         
     | 
| 
      
 382 
     | 
    
         
            +
                        mooncake_session_id=mooncake_session_id,
         
     | 
| 
      
 383 
     | 
    
         
            +
                        src_data_ptrs=self.kv_args.kv_data_ptrs,
         
     | 
| 
      
 384 
     | 
    
         
            +
                        dst_data_ptrs=dst_kv_ptrs,
         
     | 
| 
      
 385 
     | 
    
         
            +
                        item_lens=self.kv_args.kv_item_lens,
         
     | 
| 
      
 386 
     | 
    
         
            +
                        prefill_data_indices=prefill_kv_indices,
         
     | 
| 
      
 387 
     | 
    
         
            +
                        dst_data_indices=dst_kv_indices,
         
     | 
| 
      
 388 
     | 
    
         
            +
                        executor=executor,
         
     | 
| 
      
 389 
     | 
    
         
            +
                    )
         
     | 
| 
      
 390 
     | 
    
         
            +
             
     | 
| 
       348 
391 
     | 
    
         
             
                def send_kvcache_slice(
         
     | 
| 
       349 
392 
     | 
    
         
             
                    self,
         
     | 
| 
       350 
393 
     | 
    
         
             
                    mooncake_session_id: str,
         
     | 
| 
         @@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       593 
636 
     | 
    
         
             
                        f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
         
     | 
| 
       594 
637 
     | 
    
         
             
                    )
         
     | 
| 
       595 
638 
     | 
    
         | 
| 
      
 639 
     | 
    
         
            +
                def maybe_send_extra(
         
     | 
| 
      
 640 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 641 
     | 
    
         
            +
                    req: TransferInfo,
         
     | 
| 
      
 642 
     | 
    
         
            +
                    prefill_state_indices: list[int],
         
     | 
| 
      
 643 
     | 
    
         
            +
                    dst_state_data_ptrs: list[int],
         
     | 
| 
      
 644 
     | 
    
         
            +
                ):
         
     | 
| 
      
 645 
     | 
    
         
            +
                    """Send state or extra pool data with type-specific handling."""
         
     | 
| 
      
 646 
     | 
    
         
            +
                    state_type = getattr(self.kv_args, "state_type", "none")
         
     | 
| 
      
 647 
     | 
    
         
            +
             
     | 
| 
      
 648 
     | 
    
         
            +
                    if state_type == "mamba":
         
     | 
| 
      
 649 
     | 
    
         
            +
                        return self._send_mamba_state(
         
     | 
| 
      
 650 
     | 
    
         
            +
                            req,
         
     | 
| 
      
 651 
     | 
    
         
            +
                            prefill_state_indices,
         
     | 
| 
      
 652 
     | 
    
         
            +
                            dst_state_data_ptrs,
         
     | 
| 
      
 653 
     | 
    
         
            +
                        )
         
     | 
| 
      
 654 
     | 
    
         
            +
                    elif state_type in ["swa", "nsa"]:
         
     | 
| 
      
 655 
     | 
    
         
            +
                        # Reuse _send_kvcache_generic interface to send extra pool data
         
     | 
| 
      
 656 
     | 
    
         
            +
                        prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
         
     | 
| 
      
 657 
     | 
    
         
            +
                        dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
         
     | 
| 
      
 658 
     | 
    
         
            +
                        return self._send_kvcache_generic(
         
     | 
| 
      
 659 
     | 
    
         
            +
                            mooncake_session_id=req.mooncake_session_id,
         
     | 
| 
      
 660 
     | 
    
         
            +
                            src_data_ptrs=self.kv_args.state_data_ptrs,
         
     | 
| 
      
 661 
     | 
    
         
            +
                            dst_data_ptrs=dst_state_data_ptrs,
         
     | 
| 
      
 662 
     | 
    
         
            +
                            item_lens=self.kv_args.state_item_lens,
         
     | 
| 
      
 663 
     | 
    
         
            +
                            prefill_data_indices=prefill_state_indices,
         
     | 
| 
      
 664 
     | 
    
         
            +
                            dst_data_indices=dst_state_indices,
         
     | 
| 
      
 665 
     | 
    
         
            +
                            executor=self.state_executors,
         
     | 
| 
      
 666 
     | 
    
         
            +
                        )
         
     | 
| 
      
 667 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 668 
     | 
    
         
            +
                        return 0
         
     | 
| 
      
 669 
     | 
    
         
            +
             
     | 
| 
      
 670 
     | 
    
         
            +
                def _send_mamba_state(
         
     | 
| 
      
 671 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 672 
     | 
    
         
            +
                    req: TransferInfo,
         
     | 
| 
      
 673 
     | 
    
         
            +
                    prefill_mamba_index: list[int],
         
     | 
| 
      
 674 
     | 
    
         
            +
                    dst_state_data_ptrs: list[int],
         
     | 
| 
      
 675 
     | 
    
         
            +
                ):
         
     | 
| 
      
 676 
     | 
    
         
            +
                    """Transfer Mamba states."""
         
     | 
| 
      
 677 
     | 
    
         
            +
                    assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
         
     | 
| 
      
 678 
     | 
    
         
            +
             
     | 
| 
      
 679 
     | 
    
         
            +
                    transfer_blocks = []
         
     | 
| 
      
 680 
     | 
    
         
            +
                    prefill_state_data_ptrs = self.kv_args.state_data_ptrs
         
     | 
| 
      
 681 
     | 
    
         
            +
                    prefill_state_item_lens = self.kv_args.state_item_lens
         
     | 
| 
      
 682 
     | 
    
         
            +
             
     | 
| 
      
 683 
     | 
    
         
            +
                    for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
         
     | 
| 
      
 684 
     | 
    
         
            +
                        length = prefill_state_item_lens[i]
         
     | 
| 
      
 685 
     | 
    
         
            +
                        src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
         
     | 
| 
      
 686 
     | 
    
         
            +
                        dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
         
     | 
| 
      
 687 
     | 
    
         
            +
                        transfer_blocks.append((src_addr, dst_addr, length))
         
     | 
| 
      
 688 
     | 
    
         
            +
             
     | 
| 
      
 689 
     | 
    
         
            +
                    return self._transfer_data(req.mooncake_session_id, transfer_blocks)
         
     | 
| 
      
 690 
     | 
    
         
            +
             
     | 
| 
       596 
691 
     | 
    
         
             
                def sync_status_to_decode_endpoint(
         
     | 
| 
       597 
692 
     | 
    
         
             
                    self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
         
     | 
| 
       598 
693 
     | 
    
         
             
                ):
         
     | 
| 
         @@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       702 
797 
     | 
    
         
             
                                        break
         
     | 
| 
       703 
798 
     | 
    
         | 
| 
       704 
799 
     | 
    
         
             
                                    if kv_chunk.is_last:
         
     | 
| 
      
 800 
     | 
    
         
            +
                                        if kv_chunk.state_indices is not None:
         
     | 
| 
      
 801 
     | 
    
         
            +
                                            if not self.is_mla_backend and (
         
     | 
| 
      
 802 
     | 
    
         
            +
                                                self.attn_tp_size
         
     | 
| 
      
 803 
     | 
    
         
            +
                                                != target_rank_registration_info.dst_attn_tp_size
         
     | 
| 
      
 804 
     | 
    
         
            +
                                            ):
         
     | 
| 
      
 805 
     | 
    
         
            +
                                                raise RuntimeError(
         
     | 
| 
      
 806 
     | 
    
         
            +
                                                    f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
         
     | 
| 
      
 807 
     | 
    
         
            +
                                                )
         
     | 
| 
      
 808 
     | 
    
         
            +
             
     | 
| 
      
 809 
     | 
    
         
            +
                                            self.maybe_send_extra(
         
     | 
| 
      
 810 
     | 
    
         
            +
                                                req,
         
     | 
| 
      
 811 
     | 
    
         
            +
                                                kv_chunk.state_indices,
         
     | 
| 
      
 812 
     | 
    
         
            +
                                                target_rank_registration_info.dst_state_data_ptrs,
         
     | 
| 
      
 813 
     | 
    
         
            +
                                            )
         
     | 
| 
      
 814 
     | 
    
         
            +
             
     | 
| 
       705 
815 
     | 
    
         
             
                                        if self.pp_group.is_last_rank:
         
     | 
| 
       706 
816 
     | 
    
         
             
                                            # Only the last chunk we need to send the aux data
         
     | 
| 
       707 
817 
     | 
    
         
             
                                            ret = self.send_aux(
         
     | 
| 
         @@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       765 
875 
     | 
    
         
             
                                )
         
     | 
| 
       766 
876 
     | 
    
         
             
                                continue
         
     | 
| 
       767 
877 
     | 
    
         
             
                            else:
         
     | 
| 
       768 
     | 
    
         
            -
                                required_dst_info_num = int(waiting_req_bytes[ 
     | 
| 
      
 878 
     | 
    
         
            +
                                required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
         
     | 
| 
       769 
879 
     | 
    
         
             
                                room = int(room)
         
     | 
| 
       770 
880 
     | 
    
         
             
                                if room not in self.transfer_infos:
         
     | 
| 
       771 
881 
     | 
    
         
             
                                    self.transfer_infos[room] = {}
         
     | 
| 
         @@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       876 
986 
     | 
    
         
             
                    index_slice: slice,
         
     | 
| 
       877 
987 
     | 
    
         
             
                    is_last: bool,
         
     | 
| 
       878 
988 
     | 
    
         
             
                    aux_index: Optional[int] = None,
         
     | 
| 
      
 989 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
       879 
990 
     | 
    
         
             
                ):
         
     | 
| 
       880 
991 
     | 
    
         
             
                    assert self.disaggregation_mode == DisaggregationMode.PREFILL
         
     | 
| 
       881 
992 
     | 
    
         
             
                    assert not is_last or (is_last and aux_index is not None)
         
     | 
| 
         @@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager): 
     | 
|
| 
       909 
1020 
     | 
    
         
             
                            index_slice=index_slice,
         
     | 
| 
       910 
1021 
     | 
    
         
             
                            is_last=is_last,
         
     | 
| 
       911 
1022 
     | 
    
         
             
                            prefill_aux_index=aux_index,
         
     | 
| 
      
 1023 
     | 
    
         
            +
                            state_indices=state_indices,
         
     | 
| 
       912 
1024 
     | 
    
         
             
                        )
         
     | 
| 
       913 
1025 
     | 
    
         
             
                    )
         
     | 
| 
       914 
1026 
     | 
    
         | 
| 
         @@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender): 
     | 
|
| 
       989 
1101 
     | 
    
         
             
                def send(
         
     | 
| 
       990 
1102 
     | 
    
         
             
                    self,
         
     | 
| 
       991 
1103 
     | 
    
         
             
                    kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 1104 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
       992 
1105 
     | 
    
         
             
                ):
         
     | 
| 
       993 
1106 
     | 
    
         
             
                    index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         
     | 
| 
       994 
1107 
     | 
    
         
             
                    self.curr_idx += len(kv_indices)
         
     | 
| 
         @@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender): 
     | 
|
| 
       1008 
1121 
     | 
    
         
             
                            index_slice,
         
     | 
| 
       1009 
1122 
     | 
    
         
             
                            True,
         
     | 
| 
       1010 
1123 
     | 
    
         
             
                            aux_index=self.aux_index,
         
     | 
| 
      
 1124 
     | 
    
         
            +
                            state_indices=state_indices,
         
     | 
| 
       1011 
1125 
     | 
    
         
             
                        )
         
     | 
| 
       1012 
1126 
     | 
    
         | 
| 
       1013 
1127 
     | 
    
         
             
                def poll(self) -> KVPoll:
         
     | 
| 
         @@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver): 
     | 
|
| 
       1110 
1224 
     | 
    
         
             
                        packed_aux_data_ptrs = b"".join(
         
     | 
| 
       1111 
1225 
     | 
    
         
             
                            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
         
     | 
| 
       1112 
1226 
     | 
    
         
             
                        )
         
     | 
| 
      
 1227 
     | 
    
         
            +
                        packed_state_data_ptrs = b"".join(
         
     | 
| 
      
 1228 
     | 
    
         
            +
                            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
         
     | 
| 
      
 1229 
     | 
    
         
            +
                        )
         
     | 
| 
       1113 
1230 
     | 
    
         
             
                        # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
         
     | 
| 
       1114 
1231 
     | 
    
         
             
                        tp_rank = self.kv_mgr.kv_args.engine_rank
         
     | 
| 
       1115 
1232 
     | 
    
         
             
                        kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
         
     | 
| 
         @@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver): 
     | 
|
| 
       1127 
1244 
     | 
    
         
             
                                    self.session_id.encode("ascii"),
         
     | 
| 
       1128 
1245 
     | 
    
         
             
                                    packed_kv_data_ptrs,
         
     | 
| 
       1129 
1246 
     | 
    
         
             
                                    packed_aux_data_ptrs,
         
     | 
| 
      
 1247 
     | 
    
         
            +
                                    packed_state_data_ptrs,
         
     | 
| 
       1130 
1248 
     | 
    
         
             
                                    dst_tp_rank,
         
     | 
| 
       1131 
1249 
     | 
    
         
             
                                    dst_attn_tp_size,
         
     | 
| 
       1132 
1250 
     | 
    
         
             
                                    dst_kv_item_len,
         
     | 
| 
       1133 
1251 
     | 
    
         
             
                                ]
         
     | 
| 
       1134 
1252 
     | 
    
         
             
                            )
         
     | 
| 
       1135 
1253 
     | 
    
         | 
| 
       1136 
     | 
    
         
            -
                def init( 
     | 
| 
      
 1254 
     | 
    
         
            +
                def init(
         
     | 
| 
      
 1255 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 1256 
     | 
    
         
            +
                    kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 1257 
     | 
    
         
            +
                    aux_index: Optional[int] = None,
         
     | 
| 
      
 1258 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
      
 1259 
     | 
    
         
            +
                ):
         
     | 
| 
       1137 
1260 
     | 
    
         
             
                    for bootstrap_info in self.bootstrap_infos:
         
     | 
| 
       1138 
1261 
     | 
    
         
             
                        sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
         
     | 
| 
       1139 
1262 
     | 
    
         
             
                        is_dummy = bootstrap_info["is_dummy"]
         
     | 
| 
         @@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver): 
     | 
|
| 
       1147 
1270 
     | 
    
         
             
                                    self.session_id.encode("ascii"),
         
     | 
| 
       1148 
1271 
     | 
    
         
             
                                    kv_indices.tobytes() if not is_dummy else b"",
         
     | 
| 
       1149 
1272 
     | 
    
         
             
                                    str(aux_index).encode("ascii") if not is_dummy else b"",
         
     | 
| 
      
 1273 
     | 
    
         
            +
                                    (
         
     | 
| 
      
 1274 
     | 
    
         
            +
                                        np.array(
         
     | 
| 
      
 1275 
     | 
    
         
            +
                                            state_indices,
         
     | 
| 
      
 1276 
     | 
    
         
            +
                                            dtype=np.int32,
         
     | 
| 
      
 1277 
     | 
    
         
            +
                                        ).tobytes()
         
     | 
| 
      
 1278 
     | 
    
         
            +
                                        if not is_dummy and state_indices is not None
         
     | 
| 
      
 1279 
     | 
    
         
            +
                                        else b""
         
     | 
| 
      
 1280 
     | 
    
         
            +
                                    ),
         
     | 
| 
       1150 
1281 
     | 
    
         
             
                                    str(self.required_dst_info_num).encode("ascii"),
         
     | 
| 
       1151 
1282 
     | 
    
         
             
                                ]
         
     | 
| 
       1152 
1283 
     | 
    
         
             
                            )
         
     | 
| 
         @@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender): 
     | 
|
| 
       704 
704 
     | 
    
         
             
                def send(
         
     | 
| 
       705 
705 
     | 
    
         
             
                    self,
         
     | 
| 
       706 
706 
     | 
    
         
             
                    kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 707 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
       707 
708 
     | 
    
         
             
                ):
         
     | 
| 
       708 
709 
     | 
    
         
             
                    index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         
     | 
| 
       709 
710 
     | 
    
         
             
                    self.curr_idx += len(kv_indices)
         
     | 
| 
         @@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver): 
     | 
|
| 
       755 
756 
     | 
    
         
             
                            self.bootstrap_room
         
     | 
| 
       756 
757 
     | 
    
         
             
                        )
         
     | 
| 
       757 
758 
     | 
    
         | 
| 
       758 
     | 
    
         
            -
                def init( 
     | 
| 
      
 759 
     | 
    
         
            +
                def init(
         
     | 
| 
      
 760 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 761 
     | 
    
         
            +
                    kv_indices: npt.NDArray[np.int32],
         
     | 
| 
      
 762 
     | 
    
         
            +
                    aux_index: Optional[int] = None,
         
     | 
| 
      
 763 
     | 
    
         
            +
                    state_indices: Optional[List[int]] = None,
         
     | 
| 
      
 764 
     | 
    
         
            +
                ):
         
     | 
| 
       759 
765 
     | 
    
         
             
                    for bootstrap_info in self.bootstrap_infos:
         
     | 
| 
       760 
766 
     | 
    
         
             
                        logger.debug(
         
     | 
| 
       761 
767 
     | 
    
         
             
                            f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
         
     |