sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,12 @@ from sglang.srt.disaggregation.common.utils import (
 )
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
@@ -113,7 +119,7 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     dst_tp_rank: int
-    dst_tp_size: int
+    dst_attn_tp_size: int
     dst_kv_item_len: int
 
     @classmethod
@@ -126,7 +132,7 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_tp_rank=int(msg[6].decode("ascii")),
-            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_attn_tp_size=int(msg[7].decode("ascii")),
             dst_kv_item_len=int(msg[8].decode("ascii")),
         )
 
@@ -147,13 +153,18 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
-        self.tp_size = server_args.tp_size
-        self.dp_size = server_args.dp_size
-        self.enable_dp_attention = server_args.enable_dp_attention
-        if not server_args.enable_dp_attention and server_args.dp_size != 1:
-            raise ValueError(
-                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
-            )
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_dp_size = get_attention_dp_size()
+        self.attn_dp_rank = get_attention_dp_rank()
+        self.system_dp_size = (
+            1 if server_args.enable_dp_attention else server_args.dp_size
+        )
+        self.system_dp_rank = (
+            self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
+        )
+        self.pp_size = server_args.pp_size
+        self.pp_rank = self.kv_args.pp_rank
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -221,8 +232,9 @@ class MooncakeKVManager(BaseKVManager):
             )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
-            self.prefill_tp_size_table: Dict[str, int] = {}
+            self.prefill_attn_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            self.prefill_pp_size_table: Dict[str, int] = {}
             # If a timeout happens on the decode side, it means decode instances
             # fail to receive the KV Cache transfer done signal after bootstrapping.
             # These timeout requests should be aborted to release the tree cache.
@@ -245,15 +257,17 @@ class MooncakeKVManager(BaseKVManager):
             )
 
     def register_buffer_to_engine(self):
-        for kv_data_ptr, kv_data_len in zip(
-            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
-        ):
-            self.engine.register(kv_data_ptr, kv_data_len)
+        # Batch register KV data buffers
+        if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens:
+            self.engine.batch_register(
+                self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
+            )
 
-        for aux_data_ptr, aux_data_len in zip(
-            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
-        ):
-            self.engine.register(aux_data_ptr, aux_data_len)
+        # Batch register auxiliary data buffers
+        if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens:
+            self.engine.batch_register(
+                self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
+            )
 
     @cache
     def _connect(self, endpoint: str, is_ipv6: bool = False):
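Note: register_buffer_to_engine() now registers all KV and auxiliary buffers through a single batch_register() call on the MooncakeTransferEngine wrapper; the wrapper itself gains the new method in sglang/srt/disaggregation/mooncake/transfer_engine.py (+29 lines, not shown in this hunk). A minimal sketch of the idea, assuming only the per-buffer register() from the removed code is available (everything else here is illustrative, not the actual implementation):

from typing import List

def batch_register_fallback(engine, ptrs: List[int], lengths: List[int]) -> None:
    # Emulate a batched registration with the per-buffer API: skip empty lists
    # (mirroring the new guards above) and register each (ptr, length) pair once.
    if not ptrs or not lengths:
        return
    for ptr, length in zip(ptrs, lengths):
        engine.register(ptr, length)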
@@ -296,43 +310,97 @@ class MooncakeKVManager(BaseKVManager):
             prefill_kv_indices, dst_kv_indices
         )
 
-        num_layers = len(self.kv_args.kv_data_ptrs)
-        layers_params = [
-            (
-                self.kv_args.kv_data_ptrs[layer_id],
-                dst_kv_ptrs[layer_id],
-                self.kv_args.kv_item_lens[layer_id],
-            )
-            for layer_id in range(num_layers)
-        ]
+        layers_params = None
+
+        # pp is not supported on the decode side yet
+        if self.is_mla_backend:
+            src_kv_ptrs = self.kv_args.kv_data_ptrs
+            layers_per_pp_stage = len(src_kv_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_kv_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        else:
+            num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+            src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+            src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+            layers_per_pp_stage = len(src_k_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            dst_v_ptrs = dst_kv_ptrs[
+                num_kv_layers + start_layer : num_kv_layers + end_layer
+            ]
+            kv_item_len = self.kv_args.kv_item_lens[0]
 
-        # Worker function for processing a single layer
-        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            layers_params = [
+                (
+                    src_k_ptrs[layer_id],
+                    dst_k_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ] + [
+                (
+                    src_v_ptrs[layer_id],
+                    dst_v_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        assert layers_params is not None
+
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
             transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
                 transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
 
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
             return self._transfer_data(mooncake_session_id, transfer_blocks)
 
-        futures = [
-            executor.submit(
-                process_layer,
-                src_ptr,
-                dst_ptr,
-                item_len,
-            )
-            for (src_ptr, dst_ptr, item_len) in layers_params
-        ]
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
 
-        for future in concurrent.futures.as_completed(futures):
-            status = future.result()
-            if status != 0:
-                for f in futures:
-                    f.cancel()
-                return status
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
 
         return 0
 
@@ -343,7 +411,7 @@ class MooncakeKVManager(BaseKVManager):
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
         dst_tp_rank: int,
-        dst_tp_size: int,
+        dst_attn_tp_size: int,
         dst_kv_item_len: int,
         executor: concurrent.futures.ThreadPoolExecutor,
     ):
@@ -356,23 +424,22 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_size = self.tp_size // self.dp_size
-        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
         src_kv_item_len = self.kv_args.kv_item_lens[0]
-        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
+        dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size
 
         # Calculate head distribution
         src_heads_per_rank = num_kv_heads
-        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        dst_heads_per_rank = num_kv_heads * self.attn_tp_size // dst_attn_tp_size
         bytes_per_head_slice_to_send = (
             dst_kv_item_len // page_size // dst_heads_per_rank
         )
 
         # Determine slicing parameters based on TP configuration
-        if local_tp_size > dst_tp_size:
+        if self.attn_tp_size > dst_attn_tp_size:
             # Send KVCache from multiple prefill instances to 1 decode instance
             src_head_start_offset = 0
             num_heads_to_send = src_heads_per_rank
@@ -383,35 +450,55 @@ class MooncakeKVManager(BaseKVManager):
             num_heads_to_send = dst_heads_per_rank
             dst_head_start_offset = 0
 
-        layers_params = []
-        for layer_id in range(num_layers):
-            # Calculate precise byte offset and length for the sub-slice within the token
-            src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
-            dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
-            heads_bytes_per_token_to_send = (
-                num_heads_to_send * bytes_per_head_slice_to_send
+        # pp is not supported on the decode side yet
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        layers_per_pp_stage = len(src_k_ptrs)
+        start_layer = self.pp_rank * layers_per_pp_stage
+        end_layer = start_layer + layers_per_pp_stage
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            num_kv_layers + start_layer : num_kv_layers + end_layer
+        ]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+        # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+        if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
+            logger.error(
+                f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
+                f"target token slot size ({dst_kv_item_len // page_size})"
             )
+            return -1
 
-            # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
-            # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
-            if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
-                logger.error(
-                    f"[{mooncake_session_id}] Layer {layer_id}: "
-                    f"slice size ({heads_bytes_per_token_to_send}) exceeds "
-                    f"target token slot size ({dst_kv_item_len // page_size})"
-                )
-                return -1
-            layers_params.append(
-                (
-                    self.kv_args.kv_data_ptrs[layer_id],
-                    dst_kv_ptrs[layer_id],
-                    src_kv_item_len,
-                    dst_kv_item_len,
-                    src_head_slice_offset,
-                    dst_head_slice_offset,
-                    heads_bytes_per_token_to_send,
-                )
+        layers_params = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
+            )
+            for layer_id in range(layers_per_pp_stage)
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
             )
+            for layer_id in range(layers_per_pp_stage)
+        ]
 
         def process_layer_tp_aware(layer_params):
             (
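Note: the head-distribution arithmetic above determines how much of each token's KV row a prefill rank ships when prefill and decode use different attention TP sizes. A worked example with made-up sizes (8-way prefill attention TP with 2 KV heads per rank, 2-way decode attention TP, 128-dim fp16 heads, page_size 1), following only the formulas shown in this hunk:

num_kv_heads = 2                       # KV heads held by this prefill rank
attn_tp_size, dst_attn_tp_size = 8, 2  # prefill vs. decode attention TP size
page_size = 1
dst_kv_item_len = 8 * 128 * 2          # decode rank stores 8 heads * 128 dims * 2 bytes per token

src_heads_per_rank = num_kv_heads                                                   # 2
dst_heads_per_rank = num_kv_heads * attn_tp_size // dst_attn_tp_size                # 8
bytes_per_head_slice_to_send = dst_kv_item_len // page_size // dst_heads_per_rank   # 256

# attn_tp_size > dst_attn_tp_size: each prefill rank sends all of its own heads.
num_heads_to_send = src_heads_per_rank
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send    # 512
assert heads_bytes_per_token_to_send <= dst_kv_item_len // page_size                # sanity check passes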
@@ -562,9 +649,9 @@ class MooncakeKVManager(BaseKVManager):
                         target_rank_registration_info: KVArgsRegisterInfo = (
                             self.decode_kv_args_table[req.mooncake_session_id]
                         )
-                        local_tp_size = self.tp_size // self.dp_size
                         if self.is_mla_backend or (
-                            local_tp_size == target_rank_registration_info.dst_tp_size
+                            self.attn_tp_size
+                            == target_rank_registration_info.dst_attn_tp_size
                         ):
                             ret = self.send_kvcache(
                                 req.mooncake_session_id,
@@ -580,7 +667,7 @@ class MooncakeKVManager(BaseKVManager):
                                 target_rank_registration_info.dst_kv_ptrs,
                                 chunked_dst_kv_indice,
                                 target_rank_registration_info.dst_tp_rank,
-                                target_rank_registration_info.dst_tp_size,
+                                target_rank_registration_info.dst_attn_tp_size,
                                 target_rank_registration_info.dst_kv_item_len,
                                 executor,
                             )
@@ -863,11 +950,16 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
-            "tp_size": self.tp_size,
-            "dp_size": self.dp_size,
+            "attn_tp_size": self.attn_tp_size,
+            "attn_tp_rank": self.attn_tp_rank,
+            "attn_dp_size": self.attn_dp_size,
+            "attn_dp_rank": self.attn_dp_rank,
+            "pp_size": self.pp_size,
+            "pp_rank": self.pp_rank,
+            "system_dp_size": self.system_dp_size,
+            "system_dp_rank": self.system_dp_rank,
             "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
-            "engine_rank": self.kv_args.engine_rank,
         }
 
         try:
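Note: the bootstrap registration payload now carries attention TP/DP, PP, and system DP coordinates instead of the raw tp_size/dp_size/engine_rank triple. A sketch of what one prefill rank's registration against the bootstrap /route endpoint could look like; the address, port, and rank values are made up, only the field names come from the payload above:

import requests

payload = {
    "role": "Prefill",
    "attn_tp_size": 8,
    "attn_tp_rank": 3,
    "attn_dp_size": 1,
    "attn_dp_rank": 0,
    "pp_size": 2,
    "pp_rank": 1,
    "system_dp_size": 1,
    "system_dp_rank": 0,
    "rank_ip": "10.0.0.12",  # made-up worker address
    "rank_port": 8998,       # made-up worker port
}
# Bootstrap server address below is also made up.
requests.put("http://10.0.0.2:8998/route", json=payload, timeout=5)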
@@ -890,10 +982,12 @@ class MooncakeKVManager(BaseKVManager):
             ]
             for k in keys_to_remove:
                 del self.connection_pool[k]
-            if failed_bootstrap_addr in self.prefill_tp_size_table:
-                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+                del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_dp_size_table:
                 del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]
 
             possible_affected_rooms = self.addr_to_rooms_tracker.get(
                 failed_bootstrap_addr, []
@@ -915,7 +1009,7 @@ class MooncakeKVManager(BaseKVManager):
                 self.update_status(room, KVPoll.Failed)
                 affected_rooms.append(room)
         logger.error(
-            f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), affected {len(affected_rooms)} requests"
+            f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
         )
 
 
@@ -1042,10 +1136,16 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.data_parallel_rank = data_parallel_rank
 
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-            self.prefill_tp_size, self.prefill_dp_size = (
-                self._get_prefill_parallel_info_from_server()
-            )
-            if self.prefill_tp_size is None or self.prefill_dp_size is None:
+            (
+                self.prefill_attn_tp_size,
+                self.prefill_dp_size,
+                self.prefill_pp_size,
+            ) = self._get_prefill_parallel_info_from_server()
+            if (
+                self.prefill_attn_tp_size is None
+                or self.prefill_dp_size is None
+                or self.prefill_pp_size is None
+            ):
                 self.kv_mgr.record_failure(
                     self.bootstrap_room,
                     f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
@@ -1054,43 +1154,47 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 return
             else:
                 logger.debug(
-                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_tp_size}"
+                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
                 )
-                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
-                    self.prefill_tp_size
+                self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_attn_tp_size
                 )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
+                self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
+                    self.prefill_pp_size
+                )
         else:
-            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
+            self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
                 self.bootstrap_addr
             ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
+            self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
+                self.bootstrap_addr
+            ]
 
         # Currently, we don't allow prefill instance and decode instance to
         # have different TP sizes per DP rank, except for models using MLA.
-        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
-        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
-        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+        if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
             )
             self.required_dst_info_num = 1
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
-        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
+        elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
             if not self.kv_mgr.is_mla_backend:
                 logger.warning_once(
                     "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
                 )
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
-            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
+            ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
             self.required_dst_info_num = (
-                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
+                self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
             )
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
@@ -1103,10 +1207,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_ranks = [
                 rank
                 for rank in range(
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
                 )
             ]
 
@@ -1116,7 +1220,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
             self.required_prefill_response_num = (
-                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+                self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
             )
 
         if self.data_parallel_rank is not None:
@@ -1136,31 +1240,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
         if bootstrap_key not in self.kv_mgr.connection_pool:
             bootstrap_infos = []
             for target_tp_rank in self.target_tp_ranks:
-                bootstrap_info = self._get_bootstrap_info_from_server(
-                    target_tp_rank,
-                    self.target_dp_group,
-                )
-                if bootstrap_info is not None:
-                    if self.kv_mgr.is_mla_backend:
-                        # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
-                        bootstrap_info["is_dummy"] = not bool(
-                            target_tp_rank == self.target_tp_rank
-                            or self.target_tp_rank is None
+                for target_pp_rank in range(self.prefill_pp_size):
+                    bootstrap_info = self._get_bootstrap_info_from_server(
+                        target_tp_rank, self.target_dp_group, target_pp_rank
+                    )
+                    if bootstrap_info is not None:
+                        if self.kv_mgr.is_mla_backend:
+                            # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                            bootstrap_info["is_dummy"] = not bool(
+                                target_tp_rank == self.target_tp_rank
+                                or self.target_tp_rank is None
+                            )
+                        else:
+                            # For non-MLA: all target_tp_ranks are selected real ranks
+                            bootstrap_info["is_dummy"] = False
+                        logger.debug(
+                            f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
                         )
+                        bootstrap_infos.append(bootstrap_info)
                     else:
-                        # For non-MLA: all target_tp_ranks are selected real ranks
-                        bootstrap_info["is_dummy"] = False
-                    logger.debug(
-                        f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
-                    )
-                    bootstrap_infos.append(bootstrap_info)
-                else:
-                    self.kv_mgr.record_failure(
-                        self.bootstrap_room,
-                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
-                    )
-                    self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
-                    return
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
+                        )
+                        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                        return
 
         self.bootstrap_infos = bootstrap_infos
         self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
@@ -1174,10 +1278,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
 
-    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
+    def _get_bootstrap_info_from_server(
+        self, engine_rank, target_dp_group, target_pp_rank
+    ):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
             response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 bootstrap_info = response.json()
@@ -1191,24 +1297,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None
 
-    def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]:
+    def _get_prefill_parallel_info_from_server(
+        self,
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         """Fetch the prefill parallel info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return int(prefill_parallel_info["prefill_tp_size"]), int(
-                    prefill_parallel_info["prefill_dp_size"]
+                return (
+                    int(prefill_parallel_info["prefill_attn_tp_size"]),
+                    int(prefill_parallel_info["prefill_dp_size"]),
+                    int(prefill_parallel_info["prefill_pp_size"]),
                 )
             else:
                 logger.error(
                     f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                 )
-                return None, None
+                return None, None, None
         except Exception as e:
             logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
-            return None, None
+            return None, None, None
 
     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:
@@ -1218,11 +1328,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+            # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
            tp_rank = self.kv_mgr.kv_args.engine_rank
-            tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
             kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
             dst_tp_rank = str(tp_rank).encode("ascii")
-            dst_tp_size = str(tp_size).encode("ascii")
+            dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
             dst_kv_item_len = str(kv_item_len).encode("ascii")
 
             sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
@@ -1347,10 +1457,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.tp_size = None
+        self.pp_size = None
+        self.attn_tp_size = None
         self.dp_size = None
-        self.tp_size_per_dp_rank = None
-        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
+        self.prefill_port_table: Dict[
+            int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
+        ] = {}
 
         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)
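Note: the extra nesting level means the bootstrap server's port table is now keyed by DP group, then attention TP rank, then PP rank, matching the assignment in the _handle_route_put hunk below. Illustrative shape, with made-up addresses:

prefill_port_table = {
    0: {                                                     # dp_group
        3: {                                                 # attn_tp_rank
            0: {"rank_ip": "10.0.0.12", "rank_port": 8998},  # pp_rank 0
            1: {"rank_ip": "10.0.0.13", "rank_port": 8998},  # pp_rank 1
        },
    },
}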
@@ -1380,37 +1492,45 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
-        tp_size = data["tp_size"]
-        dp_size = data["dp_size"]
+        attn_tp_size = data["attn_tp_size"]
+        attn_tp_rank = data["attn_tp_rank"]
+        attn_dp_size = data["attn_dp_size"]
+        attn_dp_rank = data["attn_dp_rank"]
+        pp_size = data["pp_size"]
+        pp_rank = data["pp_rank"]
+        system_dp_size = data["system_dp_size"]
+        system_dp_rank = data["system_dp_rank"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
-        engine_rank = int(data["engine_rank"])
 
-        if self.tp_size is None:
-            self.tp_size = tp_size
+        if self.attn_tp_size is None:
+            self.attn_tp_size = attn_tp_size
 
         if self.dp_size is None:
-            self.dp_size = dp_size
+            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
 
-        tp_size_per_dp_rank = tp_size // dp_size
-        if self.tp_size_per_dp_rank is None:
-            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+        if self.pp_size is None:
+            self.pp_size = pp_size
 
         if role == "Prefill":
-            dp_group = engine_rank // tp_size_per_dp_rank
-            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
+            if system_dp_size == 1:
+                dp_group = attn_dp_rank
+            else:
+                dp_group = system_dp_rank
 
            # Add lock to make sure thread-safe
            async with self.lock:
                if dp_group not in self.prefill_port_table:
                    self.prefill_port_table[dp_group] = {}
+                if attn_tp_rank not in self.prefill_port_table[dp_group]:
+                    self.prefill_port_table[dp_group][attn_tp_rank] = {}
 
-                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
+                self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
            logger.debug(
-                f"Register prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
+                f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )
 
         return web.Response(text="OK", status=200)
@@ -1418,14 +1538,20 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         target_dp_group = request.query.get("target_dp_group")
-        if not engine_rank or not target_dp_group:
+        target_pp_rank = request.query.get("target_pp_rank")
+        if not engine_rank or not target_dp_group or not target_pp_rank:
             return web.Response(text="Missing inputs for bootstrap server.", status=400)
 
         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
-        if int(engine_rank) == -1 and int(target_dp_group) == -1:
+        if (
+            int(engine_rank) == -1
+            and int(target_dp_group) == -1
+            and int(target_pp_rank) == -1
+        ):
             prefill_parallel_info = {
-                "prefill_tp_size": self.tp_size,
+                "prefill_attn_tp_size": self.attn_tp_size,
                 "prefill_dp_size": self.dp_size,
+                "prefill_pp_size": self.pp_size,
             }
             return web.json_response(prefill_parallel_info, status=200)
 
@@ -1433,7 +1559,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                 int(engine_rank)
-            ]
+            ][int(target_pp_rank)]
 
         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)
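Note: with the route changes above, every GET to /route carries the extra target_pp_rank query parameter, and the -1/-1/-1 sentinel query also returns prefill_pp_size. A sketch of the client side, mirroring _get_prefill_parallel_info_from_server() and _get_bootstrap_info_from_server(); the bootstrap address and ranks are made up:

import requests

bootstrap_addr = "10.0.0.2:8998"  # made-up bootstrap server address

# Sentinel query used to sync the prefill parallel sizes.
info = requests.get(
    f"http://{bootstrap_addr}/route?engine_rank=-1&target_dp_group=-1&target_pp_rank=-1"
).json()
# e.g. {"prefill_attn_tp_size": 8, "prefill_dp_size": 1, "prefill_pp_size": 2}

# Per-rank lookup: the table is keyed DP group -> attn TP rank -> PP rank.
bootstrap_info = requests.get(
    f"http://{bootstrap_addr}/route?engine_rank=3&target_dp_group=0&target_pp_rank=1",
    timeout=5,
).json()  # returns the registered {"rank_ip": ..., "rank_port": ...} entry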