sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  import concurrent.futures
5
+ import ctypes
5
6
  import dataclasses
6
7
  import logging
7
8
  import os
@@ -34,6 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
34
35
  )
35
36
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
36
37
  from sglang.srt.disaggregation.utils import DisaggregationMode
38
+ from sglang.srt.distributed import get_pp_group
37
39
  from sglang.srt.layers.dp_attention import (
38
40
  get_attention_dp_rank,
39
41
  get_attention_dp_size,
@@ -137,7 +139,29 @@ class KVArgsRegisterInfo:
137
139
  )
138
140
 
139
141
 
142
+ class AuxDataCodec:
143
+ """Handles serialization and deserialization of auxiliary data buffers"""
144
+
145
+ @staticmethod
146
+ def serialize_data_from_buffer(src_addr, data_length):
147
+ """Serialize data from memory buffer to bytes"""
148
+ buffer = (ctypes.c_byte * data_length).from_address(src_addr)
149
+ return bytes(buffer)
150
+
151
+ @staticmethod
152
+ def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):
153
+ """Deserialize bytes into target memory buffer"""
154
+ dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]
155
+ item_len = kv_args.aux_item_lens[buffer_index]
156
+ dst_addr = dst_aux_ptr + item_len * aux_index
157
+ buffer = (ctypes.c_byte * len(data)).from_address(dst_addr)
158
+ buffer[:] = data
159
+ return
160
+
161
+
140
162
  class MooncakeKVManager(BaseKVManager):
163
+ AUX_DATA_HEADER = b"AUX_DATA"
164
+
141
165
  def __init__(
142
166
  self,
143
167
  args: KVArgs,
@@ -180,6 +204,7 @@ class MooncakeKVManager(BaseKVManager):
180
204
  self.session_failures = defaultdict(int)
181
205
  self.failed_sessions = set()
182
206
  self.session_lock = threading.Lock()
207
+ self.pp_group = get_pp_group()
183
208
  # Determine the number of threads to use for kv sender
184
209
  cpu_count = os.cpu_count()
185
210
  transfer_thread_pool_size = get_int_env_var(
@@ -281,21 +306,10 @@ class MooncakeKVManager(BaseKVManager):
281
306
  if not transfer_blocks:
282
307
  return 0
283
308
 
284
- # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
285
- if self.enable_custom_mem_pool:
286
- # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
287
- for src_addr, dst_addr, length in transfer_blocks:
288
- status = self.engine.transfer_sync(
289
- mooncake_session_id, src_addr, dst_addr, length
290
- )
291
- if status != 0:
292
- return status
293
- return 0
294
- else:
295
- src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
296
- return self.engine.batch_transfer_sync(
297
- mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
298
- )
309
+ src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
310
+ return self.engine.batch_transfer_sync(
311
+ mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
312
+ )
299
313
 
300
314
  def send_kvcache(
301
315
  self,
@@ -313,11 +327,11 @@ class MooncakeKVManager(BaseKVManager):
313
327
  layers_params = None
314
328
 
315
329
  # pp is not supported on the decode side yet
330
+ start_layer = self.kv_args.prefill_start_layer
331
+ end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
316
332
  if self.is_mla_backend:
317
333
  src_kv_ptrs = self.kv_args.kv_data_ptrs
318
334
  layers_per_pp_stage = len(src_kv_ptrs)
319
- start_layer = self.pp_rank * layers_per_pp_stage
320
- end_layer = start_layer + layers_per_pp_stage
321
335
  dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
322
336
  kv_item_len = self.kv_args.kv_item_lens[0]
323
337
  layers_params = [
@@ -330,17 +344,15 @@ class MooncakeKVManager(BaseKVManager):
330
344
  ]
331
345
  else:
332
346
  num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
347
+ dst_num_total_layers = num_kv_layers * self.pp_size
333
348
  src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
334
349
  src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
335
350
  layers_per_pp_stage = len(src_k_ptrs)
336
- start_layer = self.pp_rank * layers_per_pp_stage
337
- end_layer = start_layer + layers_per_pp_stage
338
351
  dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
339
352
  dst_v_ptrs = dst_kv_ptrs[
340
- num_kv_layers + start_layer : num_kv_layers + end_layer
353
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
341
354
  ]
342
355
  kv_item_len = self.kv_args.kv_item_lens[0]
343
-
344
356
  layers_params = [
345
357
  (
346
358
  src_k_ptrs[layer_id],
@@ -452,6 +464,7 @@ class MooncakeKVManager(BaseKVManager):
452
464
 
453
465
  # pp is not supported on the decode side yet
454
466
  num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
467
+ dst_num_total_layers = num_kv_layers * self.pp_size
455
468
  src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
456
469
  src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
457
470
  layers_per_pp_stage = len(src_k_ptrs)
@@ -459,7 +472,7 @@ class MooncakeKVManager(BaseKVManager):
459
472
  end_layer = start_layer + layers_per_pp_stage
460
473
  dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
461
474
  dst_v_ptrs = dst_kv_ptrs[
462
- num_kv_layers + start_layer : num_kv_layers + end_layer
475
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
463
476
  ]
464
477
 
465
478
  # Calculate precise byte offset and length for the sub-slice within the token
@@ -569,11 +582,14 @@ class MooncakeKVManager(BaseKVManager):
569
582
 
570
583
  def send_aux(
571
584
  self,
572
- mooncake_session_id: str,
585
+ req: TransferInfo,
573
586
  prefill_aux_index: int,
574
587
  dst_aux_ptrs: list[int],
575
- dst_aux_index: int,
576
588
  ):
589
+ # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
590
+ if self.enable_custom_mem_pool:
591
+ return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs)
592
+
577
593
  transfer_blocks = []
578
594
  prefill_aux_ptrs = self.kv_args.aux_data_ptrs
579
595
  prefill_aux_item_lens = self.kv_args.aux_item_lens
@@ -581,10 +597,59 @@ class MooncakeKVManager(BaseKVManager):
581
597
  for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
582
598
  length = prefill_aux_item_lens[i]
583
599
  src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
584
- dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
600
+ dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index
585
601
  transfer_blocks.append((src_addr, dst_addr, length))
586
602
 
587
- return self._transfer_data(mooncake_session_id, transfer_blocks)
603
+ return self._transfer_data(req.mooncake_session_id, transfer_blocks)
604
+
605
+ def send_aux_tcp(
606
+ self,
607
+ req: TransferInfo,
608
+ prefill_aux_index: int,
609
+ dst_aux_ptrs: list[int],
610
+ ):
611
+ prefill_aux_ptrs = self.kv_args.aux_data_ptrs
612
+ prefill_aux_item_lens = self.kv_args.aux_item_lens
613
+
614
+ for i in range(len(prefill_aux_ptrs)):
615
+ length = prefill_aux_item_lens[i]
616
+ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
617
+ data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)
618
+
619
+ self.send_aux_data_to_endpoint(
620
+ remote=req.endpoint,
621
+ dst_port=req.dst_port,
622
+ room=req.room,
623
+ buffer_index=i,
624
+ aux_index=req.dst_aux_index,
625
+ data=data,
626
+ )
627
+
628
+ return 0
629
+
630
+ def send_aux_data_to_endpoint(
631
+ self,
632
+ remote: str,
633
+ dst_port: int,
634
+ room: int,
635
+ buffer_index: int,
636
+ aux_index: int,
637
+ data: bytes,
638
+ ):
639
+ socket = self._connect(
640
+ format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
641
+ )
642
+
643
+ socket.send_multipart(
644
+ [
645
+ MooncakeKVManager.AUX_DATA_HEADER,
646
+ str(room).encode("ascii"),
647
+ str(buffer_index).encode("ascii"),
648
+ str(aux_index).encode("ascii"),
649
+ struct.pack(">I", len(data)),
650
+ data,
651
+ ]
652
+ )
588
653
 
589
654
  def sync_status_to_decode_endpoint(
590
655
  self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
@@ -612,7 +677,7 @@ class MooncakeKVManager(BaseKVManager):
612
677
  )
613
678
  polls = []
614
679
  dst_ranks_infos = []
615
- local_rank = self.kv_args.engine_rank
680
+ local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
616
681
  for req in reqs_to_be_processed:
617
682
  if not req.is_dummy:
618
683
  # Early exit if the request has failed
@@ -695,13 +760,13 @@ class MooncakeKVManager(BaseKVManager):
695
760
  break
696
761
 
697
762
  if kv_chunk.is_last:
698
- # Only the last chunk we need to send the aux data
699
- ret = self.send_aux(
700
- req.mooncake_session_id,
701
- kv_chunk.prefill_aux_index,
702
- target_rank_registration_info.dst_aux_ptrs,
703
- req.dst_aux_index,
704
- )
763
+ if self.pp_group.is_last_rank:
764
+ # Only the last chunk we need to send the aux data
765
+ ret = self.send_aux(
766
+ req,
767
+ kv_chunk.prefill_aux_index,
768
+ target_rank_registration_info.dst_aux_ptrs,
769
+ )
705
770
  polls.append(True if ret == 0 else False)
706
771
  dst_ranks_infos.append(
707
772
  (req.endpoint, req.dst_port, req.room)
@@ -776,15 +841,38 @@ class MooncakeKVManager(BaseKVManager):
776
841
 
777
842
  threading.Thread(target=bootstrap_thread).start()
778
843
 
844
+ def _handle_aux_data(self, msg: List[bytes]):
845
+ """Handle AUX_DATA messages received by the decode thread."""
846
+ room = int(msg[1].decode("ascii"))
847
+ buffer_index = int(msg[2].decode("ascii"))
848
+ aux_index = int(msg[3].decode("ascii"))
849
+ data_length = struct.unpack(">I", msg[4])[0]
850
+ data = msg[5]
851
+
852
+ if len(data) != data_length:
853
+ logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
854
+ return
855
+
856
+ AuxDataCodec.deserialize_data_to_buffer(
857
+ self.kv_args, buffer_index, aux_index, data
858
+ )
859
+
860
+ logger.debug(
861
+ f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
862
+ )
863
+
779
864
  def start_decode_thread(self):
780
865
  self.rank_port = get_free_port()
781
866
  self._bind_server_socket()
782
867
 
783
868
  def decode_thread():
784
869
  while True:
785
- (bootstrap_room, status, prefill_rank) = (
786
- self.server_socket.recv_multipart()
787
- )
870
+ msg = self.server_socket.recv_multipart()
871
+ if msg[0] == MooncakeKVManager.AUX_DATA_HEADER:
872
+ self._handle_aux_data(msg)
873
+ continue
874
+
875
+ (bootstrap_room, status, prefill_rank) = msg
788
876
  status = int(status.decode("ascii"))
789
877
  bootstrap_room = int(bootstrap_room.decode("ascii"))
790
878
  prefill_rank = int(prefill_rank.decode("ascii"))
@@ -798,10 +886,7 @@ class MooncakeKVManager(BaseKVManager):
798
886
  arrived_response_num = len(
799
887
  self.prefill_response_tracker[bootstrap_room]
800
888
  )
801
- if (
802
- self.is_mla_backend
803
- or arrived_response_num == expected_response_num
804
- ):
889
+ if arrived_response_num == expected_response_num:
805
890
  self.update_status(bootstrap_room, KVPoll.Success)
806
891
  elif status == KVPoll.Failed:
807
892
  self.record_failure(
@@ -1183,7 +1268,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
1183
1268
  self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1184
1269
  )
1185
1270
  self.required_dst_info_num = 1
1186
- self.required_prefill_response_num = 1
1271
+ self.required_prefill_response_num = 1 * (
1272
+ self.prefill_pp_size // self.kv_mgr.pp_size
1273
+ )
1187
1274
  self.target_tp_ranks = [self.target_tp_rank]
1188
1275
  elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
1189
1276
  if not self.kv_mgr.is_mla_backend:
@@ -1196,7 +1283,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
1196
1283
  self.required_dst_info_num = (
1197
1284
  self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
1198
1285
  )
1199
- self.required_prefill_response_num = 1
1286
+ self.required_prefill_response_num = 1 * (
1287
+ self.prefill_pp_size // self.kv_mgr.pp_size
1288
+ )
1200
1289
  self.target_tp_ranks = [self.target_tp_rank]
1201
1290
  else:
1202
1291
  if not self.kv_mgr.is_mla_backend:
@@ -1219,9 +1308,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
1219
1308
  # or the KVPoll will never be set correctly
1220
1309
  self.target_tp_rank = self.target_tp_ranks[0]
1221
1310
  self.required_dst_info_num = 1
1222
- self.required_prefill_response_num = (
1223
- self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1224
- )
1311
+ if self.kv_mgr.is_mla_backend:
1312
+ self.required_prefill_response_num = (
1313
+ self.prefill_pp_size // self.kv_mgr.pp_size
1314
+ )
1315
+ else:
1316
+ self.required_prefill_response_num = (
1317
+ self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1318
+ ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
1225
1319
 
1226
1320
  if self.data_parallel_rank is not None:
1227
1321
  logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -1530,7 +1624,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1530
1624
  "rank_port": rank_port,
1531
1625
  }
1532
1626
  logger.debug(
1533
- f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1627
+ f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1534
1628
  )
1535
1629
 
1536
1630
  return web.Response(text="OK", status=200)
@@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import (
43
43
  prepare_abort,
44
44
  )
45
45
  from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
46
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
47
- from sglang.srt.utils import require_mlp_sync
46
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
47
+ from sglang.srt.utils import (
48
+ DynamicGradMode,
49
+ broadcast_pyobj,
50
+ point_to_point_pyobj,
51
+ require_mlp_sync,
52
+ )
48
53
 
49
54
  if TYPE_CHECKING:
50
55
  from torch.distributed import ProcessGroup
@@ -107,6 +112,7 @@ class PrefillBootstrapQueue:
107
112
  kv_args.system_dp_rank = self.scheduler.dp_rank
108
113
  kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
109
114
  kv_args.prefill_pp_size = self.pp_size
115
+ kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
110
116
  kv_data_ptrs, kv_data_lens, kv_item_lens = (
111
117
  self.token_to_kv_pool.get_contiguous_buf_infos()
112
118
  )
@@ -172,7 +178,7 @@ class PrefillBootstrapQueue:
172
178
  if len(req.origin_input_ids) > self.max_total_num_tokens:
173
179
  message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
174
180
  logger.error(message)
175
- prepare_abort(req, message)
181
+ prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
176
182
  self.scheduler.stream_output([req], req.return_logprob)
177
183
  return True
178
184
  return False
@@ -208,8 +214,8 @@ class PrefillBootstrapQueue:
208
214
  polls = poll_and_all_reduce(
209
215
  [req.disagg_kv_sender for req in self.queue], self.gloo_group
210
216
  )
211
- for i, (req, poll) in enumerate(zip(self.queue, polls)):
212
217
 
218
+ for i, (req, poll) in enumerate(zip(self.queue, polls)):
213
219
  if rids_to_check is not None:
214
220
  # if req not in reqs_info_to_check, skip
215
221
  if req.rid not in rids_to_check:
@@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin:
395
401
  req.output_ids.append(next_token_id)
396
402
  self.tree_cache.cache_unfinished_req(req) # update the tree and lock
397
403
  self.disagg_prefill_inflight_queue.append(req)
398
- if logits_output.hidden_states is not None:
404
+ if (
405
+ logits_output is not None
406
+ and logits_output.hidden_states is not None
407
+ ):
399
408
  last_hidden_index = (
400
409
  hidden_state_offset + extend_input_len_per_req[i] - 1
401
410
  )
@@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin:
603
612
  )
604
613
  return
605
614
  req.disagg_kv_sender.send(page_indices)
615
+
616
+ # PP
617
+ @DynamicGradMode()
618
+ def event_loop_pp_disagg_prefill(self: Scheduler):
619
+ """
620
+ An event loop for the prefill server in pipeline parallelism.
621
+
622
+ Rules:
623
+ 1. Each stage runs in the same order and is notified by the previous stage.
624
+ 2. Each send/recv operation is blocking and matched by the neighboring stage.
625
+
626
+ Regular Schedule:
627
+ ====================================================================
628
+ Stage i | Stage i+1
629
+ send ith req | recv ith req
630
+ send ith proxy | recv ith proxy
631
+ send prev (i+1)th carry | recv prev (i+1)th carry
632
+ ====================================================================
633
+
634
+ Prefill Server Schedule:
635
+ ====================================================================
636
+ Stage i | Stage i+1
637
+ send ith req | recv ith req
638
+ send ith bootstrap req | recv ith bootstrap req
639
+ send ith transferred req | recv ith transferred req
640
+ send ith proxy | recv ith proxy
641
+ send prev (i+1)th carry | recv prev (i+1)th carry
642
+ send prev (i+1)th release req | recv prev (i+1)th release req
643
+ ====================================================================
644
+
645
+ There are two additional elements compared to the regular schedule:
646
+
647
+ 1. Bootstrap Requests:
648
+ a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
649
+ b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
650
+ c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
651
+
652
+ 2. Transferred Requests + Release Requests:
653
+ a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
654
+ b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
655
+ c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
656
+ """
657
+ from sglang.srt.managers.scheduler import GenerationBatchResult
658
+
659
+ mbs = [None] * self.pp_size
660
+ last_mbs = [None] * self.pp_size
661
+ self.running_mbs = [
662
+ ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
663
+ ]
664
+ bids = [None] * self.pp_size
665
+ pp_outputs: Optional[PPProxyTensors] = None
666
+
667
+ # Either success or failed
668
+ bootstrapped_rids: List[str] = []
669
+ transferred_rids: List[str] = []
670
+ release_rids: Optional[List[str]] = None
671
+
672
+ # transferred microbatch
673
+ tmbs = [None] * self.pp_size
674
+
675
+ ENABLE_RELEASE = True # For debug
676
+
677
+ while True:
678
+ server_is_idle = True
679
+
680
+ for mb_id in range(self.pp_size):
681
+ self.running_batch = self.running_mbs[mb_id]
682
+ self.last_batch = last_mbs[mb_id]
683
+
684
+ recv_reqs = self.recv_requests()
685
+
686
+ self.process_input_requests(recv_reqs)
687
+
688
+ if self.pp_group.is_first_rank:
689
+ # First rank, pop the bootstrap reqs from the bootstrap queue
690
+ bootstrapped_reqs, failed_reqs = (
691
+ self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
692
+ return_failed_reqs=True
693
+ )
694
+ )
695
+ bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
696
+ req.rid for req in failed_reqs
697
+ ]
698
+ self.waiting_queue.extend(bootstrapped_reqs)
699
+ else:
700
+ # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
701
+ bootstrapped_rids = self.recv_pyobj_from_prev_stage()
702
+ bootstrapped_reqs = (
703
+ self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
704
+ rids_to_check=bootstrapped_rids
705
+ )
706
+ )
707
+ self.waiting_queue.extend(bootstrapped_reqs)
708
+
709
+ if self.pp_group.is_first_rank:
710
+ transferred_rids = self.get_transferred_rids()
711
+ # if other ranks,
712
+ else:
713
+ # 1. recv previous stage's transferred reqs info
714
+ prev_transferred_rids = self.recv_pyobj_from_prev_stage()
715
+ # 2. get the current stage's transferred reqs info
716
+ curr_transferred_rids = self.get_transferred_rids()
717
+ # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
718
+ transferred_rids = list(
719
+ set(prev_transferred_rids) & set(curr_transferred_rids)
720
+ )
721
+
722
+ tmbs[mb_id] = transferred_rids
723
+
724
+ self.process_prefill_chunk()
725
+ mbs[mb_id] = self.get_new_batch_prefill()
726
+ self.running_mbs[mb_id] = self.running_batch
727
+
728
+ self.cur_batch = mbs[mb_id]
729
+ if self.cur_batch:
730
+ server_is_idle = False
731
+ result = self.run_batch(self.cur_batch)
732
+
733
+ # send the outputs to the next step
734
+ if self.pp_group.is_last_rank:
735
+ if self.cur_batch:
736
+ next_token_ids, bids[mb_id] = (
737
+ result.next_token_ids,
738
+ result.bid,
739
+ )
740
+ pp_outputs = PPProxyTensors(
741
+ {
742
+ "next_token_ids": next_token_ids,
743
+ }
744
+ )
745
+ # send the output from the last round to let the next stage worker run post processing
746
+ self.pp_group.send_tensor_dict(
747
+ pp_outputs.tensors,
748
+ all_gather_group=self.attn_tp_group,
749
+ )
750
+
751
+ if ENABLE_RELEASE:
752
+ if self.pp_group.is_last_rank:
753
+ # At the last stage, all stages has reached the consensus to release memory for transferred_rids
754
+ release_rids = transferred_rids
755
+ # send to the first rank
756
+ self.send_pyobj_to_next_stage(release_rids)
757
+
758
+ # receive outputs and post-process (filter finished reqs) the coming microbatch
759
+ next_mb_id = (mb_id + 1) % self.pp_size
760
+ next_pp_outputs = None
761
+ next_release_rids = None
762
+
763
+ if mbs[next_mb_id] is not None:
764
+ next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
765
+ self.pp_group.recv_tensor_dict(
766
+ all_gather_group=self.attn_tp_group
767
+ )
768
+ )
769
+ mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
770
+ output_result = GenerationBatchResult(
771
+ logits_output=None,
772
+ pp_hidden_states_proxy_tensors=None,
773
+ next_token_ids=next_pp_outputs["next_token_ids"],
774
+ extend_input_len_per_req=None,
775
+ extend_logprob_start_len_per_req=None,
776
+ bid=bids[next_mb_id],
777
+ can_run_cuda_graph=result.can_run_cuda_graph,
778
+ )
779
+ self.process_batch_result_disagg_prefill(
780
+ mbs[next_mb_id], output_result
781
+ )
782
+
783
+ last_mbs[next_mb_id] = mbs[next_mb_id]
784
+
785
+ if ENABLE_RELEASE:
786
+ if tmbs[next_mb_id] is not None:
787
+ # recv consensus rids from the previous rank
788
+ next_release_rids = self.recv_pyobj_from_prev_stage()
789
+ self.process_disagg_prefill_inflight_queue(next_release_rids)
790
+
791
+ # carry the outputs to the next stage
792
+ if not self.pp_group.is_last_rank:
793
+ if self.cur_batch:
794
+ bids[mb_id] = result.bid
795
+ if pp_outputs:
796
+ # send the outputs from the last round to let the next stage worker run post processing
797
+ self.pp_group.send_tensor_dict(
798
+ pp_outputs.tensors,
799
+ all_gather_group=self.attn_tp_group,
800
+ )
801
+ if ENABLE_RELEASE:
802
+ if release_rids is not None:
803
+ self.send_pyobj_to_next_stage(release_rids)
804
+
805
+ if not self.pp_group.is_last_rank:
806
+ # send out reqs to the next stage
807
+ self.send_pyobj_to_next_stage(recv_reqs)
808
+ self.send_pyobj_to_next_stage(bootstrapped_rids)
809
+ self.send_pyobj_to_next_stage(transferred_rids)
810
+
811
+ # send out proxy tensors to the next stage
812
+ if self.cur_batch:
813
+ self.pp_group.send_tensor_dict(
814
+ result.pp_hidden_states_proxy_tensors,
815
+ all_gather_group=self.attn_tp_group,
816
+ )
817
+
818
+ pp_outputs = next_pp_outputs
819
+ release_rids = next_release_rids
820
+
821
+ self.running_batch.batch_is_full = False
822
+
823
+ if not ENABLE_RELEASE:
824
+ if len(self.disagg_prefill_inflight_queue) > 0:
825
+ self.process_disagg_prefill_inflight_queue()
826
+
827
+ # When the server is idle, self-check and re-init some states
828
+ if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
829
+ self.check_memory()
830
+ self.check_tree_cache()
831
+ self.new_token_ratio = self.init_new_token_ratio
832
+
833
+ def send_pyobj_to_next_stage(self, data):
834
+ if self.attn_tp_rank == 0:
835
+ dp_offset = self.attn_dp_rank * self.attn_tp_size
836
+ point_to_point_pyobj(
837
+ data,
838
+ self.pp_rank * self.tp_size + dp_offset,
839
+ self.world_group.device_group,
840
+ self.pp_rank * self.tp_size + dp_offset,
841
+ ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset,
842
+ )
843
+
844
+ def recv_pyobj_from_prev_stage(self):
845
+ if self.attn_tp_rank == 0:
846
+ dp_offset = self.attn_dp_rank * self.attn_tp_size
847
+ data = point_to_point_pyobj(
848
+ [],
849
+ self.pp_rank * self.tp_size + dp_offset,
850
+ self.world_group.device_group,
851
+ ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset,
852
+ self.pp_rank * self.tp_size + dp_offset,
853
+ )
854
+ else:
855
+ data = None
856
+
857
+ if self.tp_size != 1:
858
+ data = broadcast_pyobj(
859
+ data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
860
+ )
861
+ return data
@@ -99,7 +99,8 @@ class MetadataBuffers:
99
99
  # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
100
100
  device = "npu"
101
101
  elif self.custom_mem_pool:
102
- device = "cuda"
102
+ # TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
103
+ device = "cpu"
103
104
  with (
104
105
  torch.cuda.use_mem_pool(self.custom_mem_pool)
105
106
  if self.custom_mem_pool
@@ -398,7 +398,7 @@ class CustomAllreduce:
398
398
  else:
399
399
  # If warm up, mimic the allocation pattern since custom
400
400
  # allreduce is out-of-place.
401
- return torch.empty_like(input)
401
+ return torch.zeros_like(input)
402
402
  else:
403
403
  if _is_hip:
404
404
  # note: outside of cuda graph context,