sglang 0.5.3.post2__py3-none-any.whl → 0.5.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. sglang/bench_one_batch.py +13 -8
  2. sglang/srt/disaggregation/base/conn.py +17 -4
  3. sglang/srt/disaggregation/common/conn.py +1 -0
  4. sglang/srt/disaggregation/decode.py +113 -8
  5. sglang/srt/disaggregation/fake/conn.py +11 -3
  6. sglang/srt/disaggregation/mooncake/conn.py +148 -17
  7. sglang/srt/disaggregation/nixl/conn.py +7 -1
  8. sglang/srt/disaggregation/prefill.py +71 -1
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -3
  10. sglang/srt/environ.py +3 -3
  11. sglang/srt/layers/attention/ascend_backend.py +17 -0
  12. sglang/srt/layers/layernorm.py +41 -9
  13. sglang/srt/layers/logits_processor.py +1 -1
  14. sglang/srt/layers/moe/utils.py +4 -2
  15. sglang/srt/layers/rotary_embedding.py +16 -2
  16. sglang/srt/layers/sampler.py +3 -3
  17. sglang/srt/managers/scheduler.py +0 -6
  18. sglang/srt/mem_cache/allocator_ascend.py +1 -1
  19. sglang/srt/mem_cache/common.py +1 -5
  20. sglang/srt/mem_cache/memory_pool.py +248 -137
  21. sglang/srt/model_executor/model_runner.py +28 -13
  22. sglang/srt/model_executor/npu_graph_runner.py +2 -2
  23. sglang/srt/model_loader/weight_utils.py +2 -2
  24. sglang/srt/models/deepseek_v2.py +1 -0
  25. sglang/srt/models/glm4_moe.py +4 -2
  26. sglang/srt/server_args.py +31 -9
  27. sglang/srt/speculative/eagle_worker.py +2 -2
  28. sglang/srt/speculative/spec_info.py +2 -0
  29. sglang/srt/speculative/standalone_worker.py +1 -1
  30. sglang/test/runners.py +1 -1
  31. sglang/test/send_one.py +27 -1
  32. sglang/test/test_disaggregation_utils.py +33 -15
  33. sglang/test/test_utils.py +37 -2
  34. sglang/version.py +1 -1
  35. {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/METADATA +1 -1
  36. {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/RECORD +39 -39
  37. {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/WHEEL +0 -0
  38. {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/licenses/LICENSE +0 -0
  39. {sglang-0.5.3.post2.dist-info → sglang-0.5.3.post3.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
+    is_cuda_alike,
+    is_xpu,
     kill_process_tree,
     require_mlp_sync,
     require_mlp_tp_gather,
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
 )
 from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 
+profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
+    profiler_activity
+    for available, profiler_activity in [
+        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
+        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
+    ]
+    if available
+]
+
 
 @dataclasses.dataclass
 class BenchArgs:
@@ -424,10 +435,7 @@ def latency_test_run_once(
     profiler = None
     if profile:
         profiler = torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
+            activities=profile_activities,
             with_stack=True,
             record_shapes=profile_record_shapes,
         )
@@ -460,10 +468,7 @@
         if profile and i == output_len / 2:
             profiler = None
             profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
+                activities=profile_activities,
                 with_stack=True,
                 record_shapes=profile_record_shapes,
            )
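The new module-level profile_activities list builds the profiler activity set once: CPU is always profiled, and the CUDA/XPU activities are appended only where the corresponding backend is available, so both call sites in latency_test_run_once can share it. A minimal self-contained sketch of the same pattern follows; the two probe stubs are assumptions for illustration (the real helpers live in sglang.srt.utils), and ProfilerActivity.XPU requires a recent PyTorch build:

    import torch

    def is_cuda_alike() -> bool:
        # Stub for sglang.srt.utils.is_cuda_alike (CUDA or ROCm); assumption for this sketch.
        return torch.cuda.is_available()

    def is_xpu() -> bool:
        # Stub for sglang.srt.utils.is_xpu; assumption for this sketch.
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    # Always profile CPU; add each accelerator activity only when present.
    profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
        activity
        for available, activity in [
            (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
            (is_xpu(), torch.profiler.ProfilerActivity.XPU),
        ]
        if available
    ]

    with torch.profiler.profile(activities=profile_activities) as prof:
        torch.ones(4) @ torch.ones(4, 4)  # any small workload
    print(prof.key_averages().table(row_limit=5))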
sglang/srt/disaggregation/base/conn.py CHANGED
@@ -20,6 +20,10 @@ class KVArgs:
     aux_data_ptrs: List[int]
     aux_data_lens: List[int]
     aux_item_lens: List[int]
+    state_data_ptrs: List[int]
+    state_data_lens: List[int]
+    state_item_lens: List[int]
+    state_type: str  # "none", "mamba", "swa"
     ib_device: str
     ib_traffic_class: str
     gpu_id: int
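To make the four new fields concrete, here is a hypothetical population for a Mamba-hybrid worker, assuming KVArgs stays an annotation-only class that can be instantiated bare; every numeric value below is invented for illustration:

    kv_args = KVArgs()
    # One entry per state buffer registered with the transfer engine.
    kv_args.state_data_ptrs = [0x7F0000000000]    # device address of the state pool (invented)
    kv_args.state_data_lens = [64 * 1024 * 1024]  # total bytes in that pool (invented)
    kv_args.state_item_lens = [256 * 1024]        # bytes per per-request state slot (invented)
    kv_args.state_type = "mamba"                  # decode.py also sets "swa", "nsa", or "none"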
@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
         ...
 
     @abstractmethod
-    def send(self, kv_indices: npt.NDArray[np.int32]):
+    def send(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
+    ):
         """
-        Send the kv cache at the given kv indices to the decoder server
+        Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
         """
         ...
 
@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
     ): ...
 
     @abstractmethod
-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         """
-        Notify the prefill server about the kv indices and aux index
+        Notify the prefill server about the kv indices, aux index, and state_indices.
         """
         ...
 
sglang/srt/disaggregation/common/conn.py CHANGED
@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         pass
 
sglang/srt/disaggregation/decode.py CHANGED
@@ -25,11 +25,12 @@ import time
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from torch.distributed import ProcessGroup
 
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import (
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
+    KVCache,
+    NSATokenToKVPool,
+    ReqToTokenPool,
+    SWAKVPool,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -124,6 +135,35 @@ class DecodeReqToTokenPool:
         self.free_slots = list(range(self.size + self.pre_alloc_size))
 
 
+class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int,
+        pre_alloc_size: int,
+    ):
+        DecodeReqToTokenPool.__init__(
+            self,
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+            pre_alloc_size=pre_alloc_size,
+        )
+        self._init_mamba_pool(
+            size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
+        )
+
+    def clear(self):
+        self.free_slots = list(range(self.size + self.pre_alloc_size))
+        self.mamba_pool.clear()
+
+
 @dataclass
 class DecodeRequest:
     req: Req
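Note the initialization pattern above: the class subclasses HybridReqToTokenPool for its mamba-aware interface, but calls DecodeReqToTokenPool.__init__ directly (not super().__init__()) so that free_slots covers the pre-allocated region. A generic sketch of that borrow-an-initializer pattern, with stand-in classes that are not sglang APIs:

    class SlotPool:  # stand-in for DecodeReqToTokenPool
        def __init__(self, size: int, pre_alloc_size: int):
            self.size = size
            self.pre_alloc_size = pre_alloc_size
            self.free_slots = list(range(size + pre_alloc_size))

    class HybridPool:  # stand-in for HybridReqToTokenPool
        def _init_extra_pool(self, total_slots: int):
            self.extra_pool = [None] * total_slots

    class DecodeHybridPool(HybridPool):  # stand-in for HybridMambaDecodeReqToTokenPool
        def __init__(self, size: int, pre_alloc_size: int):
            # Borrow SlotPool's initializer directly, bypassing the MRO,
            # then run the hybrid-specific setup over all slots.
            SlotPool.__init__(self, size, pre_alloc_size)
            self._init_extra_pool(size + pre_alloc_size)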
@@ -217,6 +257,28 @@ class DecodePreallocQueue:
             self.metadata_buffers.get_buf_infos()
         )
 
+        if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
+            state_data_ptrs, state_data_lens, state_item_lens = (
+                self.token_to_kv_pool.get_state_buf_infos()
+            )
+            kv_args.state_data_ptrs = state_data_ptrs
+            kv_args.state_data_lens = state_data_lens
+            kv_args.state_item_lens = state_item_lens
+
+            if isinstance(self.token_to_kv_pool, SWAKVPool):
+                kv_args.state_type = "swa"
+            elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                kv_args.state_type = "mamba"
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                kv_args.state_type = "nsa"
+            else:
+                kv_args.state_type = "none"
+        else:
+            kv_args.state_data_ptrs = []
+            kv_args.state_data_lens = []
+            kv_args.state_item_lens = []
+            kv_args.state_type = "none"
+
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class: Type[BaseKVManager] = get_kv_class(
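The check is duck-typed: any KV pool exposing get_state_buf_infos() that returns parallel lists of (device pointers, total byte lengths, per-slot byte lengths) gets its state buffers advertised, and the isinstance chain only picks the state_type label. A hypothetical minimal pool satisfying that contract (class name and sizes invented):

    import torch

    class ToyStatePool:
        def __init__(self, num_slots: int, slot_bytes: int):
            self.slot_bytes = slot_bytes
            self.state_buf = torch.zeros(num_slots * slot_bytes, dtype=torch.uint8)

        def get_state_buf_infos(self):
            return (
                [self.state_buf.data_ptr()],  # state_data_ptrs
                [self.state_buf.numel()],     # state_data_lens (total bytes)
                [self.slot_bytes],            # state_item_lens (bytes per slot)
            )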
@@ -414,16 +476,56 @@
                 .cpu()
                 .numpy()
             )
+            page_size = self.token_to_kv_pool_allocator.page_size
+
+            # Prepare extra pool indices for hybrid models
+            if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                # Mamba hybrid model: single mamba state index
+                state_indices = [
+                    self.req_to_token_pool.req_index_to_mamba_index_mapping[
+                        decode_req.req.req_pool_idx
+                    ]
+                    .cpu()
+                    .numpy()
+                ]
+            elif isinstance(self.token_to_kv_pool, SWAKVPool):
+                # SWA hybrid model: send decode-side SWA window indices
+                seq_len = len(decode_req.req.origin_input_ids)
+                window_size = self.scheduler.sliding_window_size
+
+                window_start = max(0, seq_len - window_size)
+                window_start = (window_start // page_size) * page_size
+                window_kv_indices_full = self.req_to_token_pool.req_to_token[
+                    decode_req.req.req_pool_idx, window_start:seq_len
+                ]
+
+                # Translate to SWA pool indices
+                window_kv_indices_swa = (
+                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                        window_kv_indices_full
+                    )
+                )
+                state_indices = window_kv_indices_swa.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                seq_len = len(decode_req.req.origin_input_ids)
+                kv_indices_full = self.req_to_token_pool.req_to_token[
+                    decode_req.req.req_pool_idx, :seq_len
+                ]
+                state_indices = kv_indices_full.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            else:
+                state_indices = None
 
             decode_req.metadata_buffer_index = (
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert decode_req.metadata_buffer_index is not None
-            page_indices = kv_to_page_indices(
-                kv_indices, self.token_to_kv_pool_allocator.page_size
+            page_indices = kv_to_page_indices(kv_indices, page_size)
+            decode_req.kv_receiver.init(
+                page_indices, decode_req.metadata_buffer_index, state_indices
             )
-            decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
-
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
             decode_req.req.time_stats.decode_transfer_queue_entry_time = (
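In the SWA branch, the window start is rounded down to a page boundary so the transferred slice always begins on a whole page. Worked through with invented numbers:

    seq_len = 1000      # prompt tokens on the prefill side
    window_size = 256   # SWA sliding window
    page_size = 64

    window_start = max(0, seq_len - window_size)            # 744
    window_start = (window_start // page_size) * page_size  # 704, page-aligned
    assert window_start % page_size == 0
    # 1000 - 704 = 296 tokens (out of 1000) are translated to SWA pool
    # indices and sent, rather than the full sequence.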
@@ -503,7 +605,10 @@
 
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
-        req_pool_indices = self.req_to_token_pool.alloc(1)
+        if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
+            req_pool_indices = self.req_to_token_pool.alloc(1, [req])
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(1)
 
         assert (
             req_pool_indices is not None
sglang/srt/disaggregation/fake/conn.py CHANGED
@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         self.has_sent = True
-        logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
+        logger.debug(
+            f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
+        )
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
         logger.debug("FakeKVReceiver poll success")
         return KVPoll.Success
 
-    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: list[int],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         self.has_init = True
         logger.debug(
-            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
+            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
         )
 
     def failure_exception(self):
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -58,6 +58,7 @@ class TransferKVChunk:
     index_slice: slice
     is_last: bool
     prefill_aux_index: Optional[int]
+    state_indices: Optional[List[int]]
 
 
 # decode
@@ -69,6 +70,7 @@ class TransferInfo:
     mooncake_session_id: str
     dst_kv_indices: npt.NDArray[np.int32]
     dst_aux_index: int
+    dst_state_indices: List[int]
     required_dst_info_num: int
     is_dummy: bool
 
@@ -78,9 +80,14 @@
             is_dummy = True
             dst_kv_indices = np.array([], dtype=np.int32)
             dst_aux_index = None
+            dst_state_indices = []
         else:
             dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
             dst_aux_index = int(msg[5].decode("ascii"))
+            if msg[6] == b"":
+                dst_state_indices = []
+            else:
+                dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
             is_dummy = False
         return cls(
             room=int(msg[0].decode("ascii")),
@@ -89,7 +96,8 @@
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_indices=dst_kv_indices,
             dst_aux_index=dst_aux_index,
-            required_dst_info_num=int(msg[6].decode("ascii")),
+            dst_state_indices=dst_state_indices,
+            required_dst_info_num=int(msg[7].decode("ascii")),
             is_dummy=is_dummy,
         )
 
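Because a new frame carrying the serialized state indices is inserted at position 6, required_dst_info_num moves from msg[6] to msg[7] (the bootstrap-thread parser and KVArgsRegisterInfo shift the same way below). A round-trip sketch of just the framing, with invented frame contents; the real message travels over a ZMQ socket:

    import numpy as np

    state_indices = [3, 4, 5]
    msg = [
        b"1234",                                            # msg[0] bootstrap room
        b"host",                                            # msg[1] endpoint (invented)
        b"5757",                                            # msg[2] dst port (invented)
        b"session-0",                                       # msg[3] mooncake session id
        np.array([7, 8], dtype=np.int32).tobytes(),         # msg[4] dst kv indices
        b"0",                                               # msg[5] dst aux index
        np.array(state_indices, dtype=np.int32).tobytes(),  # msg[6] NEW: dst state indices
        b"1",                                               # msg[7] required_dst_info_num (was msg[6])
    ]

    # Prefill-side decoding, mirroring TransferInfo.from_zmq:
    dst_state_indices = [] if msg[6] == b"" else list(np.frombuffer(msg[6], dtype=np.int32))
    required_dst_info_num = int(msg[7].decode("ascii"))
    assert list(dst_state_indices) == [3, 4, 5] and required_dst_info_num == 1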
@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
     mooncake_session_id: str
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
+    dst_state_data_ptrs: list[int]
     dst_tp_rank: int
     dst_attn_tp_size: int
     dst_kv_item_len: int
@@ -116,9 +125,10 @@
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_tp_rank=int(msg[6].decode("ascii")),
-            dst_attn_tp_size=int(msg[7].decode("ascii")),
-            dst_kv_item_len=int(msg[8].decode("ascii")),
+            dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+            dst_tp_rank=int(msg[7].decode("ascii")),
+            dst_attn_tp_size=int(msg[8].decode("ascii")),
+            dst_kv_item_len=int(msg[9].decode("ascii")),
         )
 
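The pointer lists travel as packed native-endian unsigned 64-bit integers, one "Q" per pointer, so len(frame) // 8 recovers the count. Round trip of the encoding used for the new dst_state_data_ptrs frame (pointer values invented):

    import struct

    state_data_ptrs = [0x7F0000000000, 0x7F0040000000]  # invented device addresses

    # Decode side (MooncakeKVReceiver): pack one uint64 per pointer.
    packed = b"".join(struct.pack("Q", ptr) for ptr in state_data_ptrs)

    # Prefill side (KVArgsRegisterInfo.from_zmq): unpack them all back.
    unpacked = list(struct.unpack(f"{len(packed)//8}Q", packed))
    assert unpacked == state_data_ptrs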
@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager):
             )
             for _ in range(transfer_queue_size)
         ]
+        self.state_executors = concurrent.futures.ThreadPoolExecutor(
+            transfer_thread_pool_size // transfer_queue_size
+        )
         for queue, executor in zip(self.transfer_queues, self.executors):
             threading.Thread(
                 target=self.transfer_worker, args=(queue, executor), daemon=True
@@ -239,6 +252,12 @@
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         )
 
+        # Batch register state/extra pool data buffers
+        if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
+            self.engine.batch_register(
+                self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
+            )
+
     def _transfer_data(self, mooncake_session_id, transfer_blocks):
         if not transfer_blocks:
             return 0
@@ -248,17 +267,23 @@
             mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
         )
 
-    def send_kvcache(
+    def _send_kvcache_generic(
         self,
         mooncake_session_id: str,
-        prefill_kv_indices: npt.NDArray[np.int32],
-        dst_kv_ptrs: list[int],
-        dst_kv_indices: npt.NDArray[np.int32],
+        src_data_ptrs: list[int],
+        dst_data_ptrs: list[int],
+        item_lens: list[int],
+        prefill_data_indices: npt.NDArray[np.int32],
+        dst_data_indices: npt.NDArray[np.int32],
         executor: concurrent.futures.ThreadPoolExecutor,
-    ):
-        # Group by indices
+    ) -> int:
+        """
+        Generic KV cache transfer supporting both MHA and MLA architectures.
+        This method is used by both send_kvcache (full pool) and maybe_send_extra.
+        """
+        # Group by indices for optimization
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
-            prefill_kv_indices, dst_kv_indices
+            prefill_data_indices, dst_data_indices
         )
 
         layers_params = None
@@ -266,9 +291,9 @@
         # pp is not supported on the decode side yet
         if self.is_mla_backend:
             src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
-                self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
+                self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
             )
-            kv_item_len = self.kv_args.kv_item_lens[0]
+            kv_item_len = item_lens[0]
             layers_params = [
                 (
                     src_kv_ptrs[layer_id],
@@ -279,9 +304,9 @@
             ]
         else:
             src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
-                self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
+                self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
             )
-            kv_item_len = self.kv_args.kv_item_lens[0]
+            kv_item_len = item_lens[0]
             layers_params = [
                 (
                     src_k_ptrs[layer_id],
@@ -345,6 +370,24 @@
 
         return 0
 
+    def send_kvcache(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        return self._send_kvcache_generic(
+            mooncake_session_id=mooncake_session_id,
+            src_data_ptrs=self.kv_args.kv_data_ptrs,
+            dst_data_ptrs=dst_kv_ptrs,
+            item_lens=self.kv_args.kv_item_lens,
+            prefill_data_indices=prefill_kv_indices,
+            dst_data_indices=dst_kv_indices,
+            executor=executor,
+        )
+
     def send_kvcache_slice(
         self,
         mooncake_session_id: str,
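Before anything is sent, group_concurrent_contiguous coalesces index pairs into runs that are contiguous on both the source and destination side, so each run becomes one large transfer instead of many per-page ones. A simplified illustrative reimplementation of that grouping (not the sglang source):

    import numpy as np

    def group_runs(src: np.ndarray, dst: np.ndarray):
        """Split (src, dst) index pairs into runs contiguous in both arrays."""
        runs, start = [], 0
        for i in range(1, len(src)):
            if src[i] != src[i - 1] + 1 or dst[i] != dst[i - 1] + 1:
                runs.append((src[start:i], dst[start:i]))
                start = i
        if len(src):
            runs.append((src[start:], dst[start:]))
        return runs

    src = np.array([10, 11, 12, 40, 41])
    dst = np.array([0, 1, 2, 3, 4])
    # src jumps at 40 while dst stays contiguous, so two runs come back:
    # ([10, 11, 12], [0, 1, 2]) and ([40, 41], [3, 4]).
    print(group_runs(src, dst))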
@@ -593,6 +636,58 @@
             f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
         )
 
+    def maybe_send_extra(
+        self,
+        req: TransferInfo,
+        prefill_state_indices: list[int],
+        dst_state_data_ptrs: list[int],
+    ):
+        """Send state or extra pool data with type-specific handling."""
+        state_type = getattr(self.kv_args, "state_type", "none")
+
+        if state_type == "mamba":
+            return self._send_mamba_state(
+                req,
+                prefill_state_indices,
+                dst_state_data_ptrs,
+            )
+        elif state_type in ["swa", "nsa"]:
+            # Reuse _send_kvcache_generic interface to send extra pool data
+            prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
+            dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
+            return self._send_kvcache_generic(
+                mooncake_session_id=req.mooncake_session_id,
+                src_data_ptrs=self.kv_args.state_data_ptrs,
+                dst_data_ptrs=dst_state_data_ptrs,
+                item_lens=self.kv_args.state_item_lens,
+                prefill_data_indices=prefill_state_indices,
+                dst_data_indices=dst_state_indices,
+                executor=self.state_executors,
+            )
+        else:
+            return 0
+
+    def _send_mamba_state(
+        self,
+        req: TransferInfo,
+        prefill_mamba_index: list[int],
+        dst_state_data_ptrs: list[int],
+    ):
+        """Transfer Mamba states."""
+        assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
+
+        transfer_blocks = []
+        prefill_state_data_ptrs = self.kv_args.state_data_ptrs
+        prefill_state_item_lens = self.kv_args.state_item_lens
+
+        for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
+            length = prefill_state_item_lens[i]
+            src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
+            dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
+            transfer_blocks.append((src_addr, dst_addr, length))
+
+        return self._transfer_data(req.mooncake_session_id, transfer_blocks)
+
     def sync_status_to_decode_endpoint(
         self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
     ):
702
797
  break
703
798
 
704
799
  if kv_chunk.is_last:
800
+ if kv_chunk.state_indices is not None:
801
+ if not self.is_mla_backend and (
802
+ self.attn_tp_size
803
+ != target_rank_registration_info.dst_attn_tp_size
804
+ ):
805
+ raise RuntimeError(
806
+ f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
807
+ )
808
+
809
+ self.maybe_send_extra(
810
+ req,
811
+ kv_chunk.state_indices,
812
+ target_rank_registration_info.dst_state_data_ptrs,
813
+ )
814
+
705
815
  if self.pp_group.is_last_rank:
706
816
  # Only the last chunk we need to send the aux data
707
817
  ret = self.send_aux(
@@ -765,7 +875,7 @@
                 )
                 continue
             else:
-                required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
+                required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
             room = int(room)
             if room not in self.transfer_infos:
                 self.transfer_infos[room] = {}
@@ -876,6 +986,7 @@
         index_slice: slice,
         is_last: bool,
         aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
     ):
         assert self.disaggregation_mode == DisaggregationMode.PREFILL
         assert not is_last or (is_last and aux_index is not None)
@@ -909,6 +1020,7 @@
                 index_slice=index_slice,
                 is_last=is_last,
                 prefill_aux_index=aux_index,
+                state_indices=state_indices,
             )
         )
 
@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
    ):
         index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         self.curr_idx += len(kv_indices)
@@ -1008,6 +1121,7 @@
             index_slice,
             True,
             aux_index=self.aux_index,
+            state_indices=state_indices,
         )
 
     def poll(self) -> KVPoll:
@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
         packed_aux_data_ptrs = b"".join(
             struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
         )
+        packed_state_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
+        )
         # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
         tp_rank = self.kv_mgr.kv_args.engine_rank
         kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
@@ -1127,13 +1244,19 @@
                 self.session_id.encode("ascii"),
                 packed_kv_data_ptrs,
                 packed_aux_data_ptrs,
+                packed_state_data_ptrs,
                 dst_tp_rank,
                 dst_attn_tp_size,
                 dst_kv_item_len,
             ]
         )
 
-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         for bootstrap_info in self.bootstrap_infos:
             sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
             is_dummy = bootstrap_info["is_dummy"]
@@ -1147,6 +1270,14 @@
                     self.session_id.encode("ascii"),
                     kv_indices.tobytes() if not is_dummy else b"",
                     str(aux_index).encode("ascii") if not is_dummy else b"",
+                    (
+                        np.array(
+                            state_indices,
+                            dtype=np.int32,
+                        ).tobytes()
+                        if not is_dummy and state_indices is not None
+                        else b""
+                    ),
                     str(self.required_dst_info_num).encode("ascii"),
                 ]
             )
sglang/srt/disaggregation/nixl/conn.py CHANGED
@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         self.curr_idx += len(kv_indices)
@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver):
             self.bootstrap_room
         )
 
-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         for bootstrap_info in self.bootstrap_infos:
             logger.debug(
                 f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"