sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,12 @@ from sglang.srt.disaggregation.common.utils import (
34
34
  )
35
35
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
36
36
  from sglang.srt.disaggregation.utils import DisaggregationMode
37
+ from sglang.srt.layers.dp_attention import (
38
+ get_attention_dp_rank,
39
+ get_attention_dp_size,
40
+ get_attention_tp_rank,
41
+ get_attention_tp_size,
42
+ )
37
43
  from sglang.srt.server_args import ServerArgs
38
44
  from sglang.srt.utils import (
39
45
  format_tcp_address,
@@ -113,7 +119,7 @@ class KVArgsRegisterInfo:
113
119
  dst_kv_ptrs: list[int]
114
120
  dst_aux_ptrs: list[int]
115
121
  dst_tp_rank: int
116
- dst_tp_size: int
122
+ dst_attn_tp_size: int
117
123
  dst_kv_item_len: int
118
124
 
119
125
  @classmethod
@@ -126,7 +132,7 @@ class KVArgsRegisterInfo:
126
132
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
127
133
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
128
134
  dst_tp_rank=int(msg[6].decode("ascii")),
129
- dst_tp_size=int(msg[7].decode("ascii")),
135
+ dst_attn_tp_size=int(msg[7].decode("ascii")),
130
136
  dst_kv_item_len=int(msg[8].decode("ascii")),
131
137
  )
132
138
 
@@ -147,13 +153,18 @@ class MooncakeKVManager(BaseKVManager):
147
153
  # for p/d multi node infer
148
154
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
149
155
  self.dist_init_addr = server_args.dist_init_addr
150
- self.tp_size = server_args.tp_size
151
- self.dp_size = server_args.dp_size
152
- self.enable_dp_attention = server_args.enable_dp_attention
153
- if not server_args.enable_dp_attention and server_args.dp_size != 1:
154
- raise ValueError(
155
- "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
156
- )
156
+ self.attn_tp_size = get_attention_tp_size()
157
+ self.attn_tp_rank = get_attention_tp_rank()
158
+ self.attn_dp_size = get_attention_dp_size()
159
+ self.attn_dp_rank = get_attention_dp_rank()
160
+ self.system_dp_size = (
161
+ 1 if server_args.enable_dp_attention else server_args.dp_size
162
+ )
163
+ self.system_dp_rank = (
164
+ self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
165
+ )
166
+ self.pp_size = server_args.pp_size
167
+ self.pp_rank = self.kv_args.pp_rank
157
168
  self.request_status: Dict[int, KVPoll] = {}
158
169
  self.rank_port = None
159
170
  self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -221,8 +232,9 @@ class MooncakeKVManager(BaseKVManager):
221
232
  )
222
233
  self.start_decode_thread()
223
234
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
224
- self.prefill_tp_size_table: Dict[str, int] = {}
235
+ self.prefill_attn_tp_size_table: Dict[str, int] = {}
225
236
  self.prefill_dp_size_table: Dict[str, int] = {}
237
+ self.prefill_pp_size_table: Dict[str, int] = {}
226
238
  # If a timeout happens on the decode side, it means decode instances
227
239
  # fail to receive the KV Cache transfer done signal after bootstrapping.
228
240
  # These timeout requests should be aborted to release the tree cache.
@@ -296,15 +308,53 @@ class MooncakeKVManager(BaseKVManager):
296
308
  prefill_kv_indices, dst_kv_indices
297
309
  )
298
310
 
299
- num_layers = len(self.kv_args.kv_data_ptrs)
300
- layers_params = [
301
- (
302
- self.kv_args.kv_data_ptrs[layer_id],
303
- dst_kv_ptrs[layer_id],
304
- self.kv_args.kv_item_lens[layer_id],
305
- )
306
- for layer_id in range(num_layers)
307
- ]
311
+ layers_params = None
312
+
313
+ # pp is not supported on the decode side yet
314
+ if self.is_mla_backend:
315
+ src_kv_ptrs = self.kv_args.kv_data_ptrs
316
+ layers_per_pp_stage = len(src_kv_ptrs)
317
+ start_layer = self.pp_rank * layers_per_pp_stage
318
+ end_layer = start_layer + layers_per_pp_stage
319
+ dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
320
+ kv_item_len = self.kv_args.kv_item_lens[0]
321
+ layers_params = [
322
+ (
323
+ src_kv_ptrs[layer_id],
324
+ dst_kv_ptrs[layer_id],
325
+ kv_item_len,
326
+ )
327
+ for layer_id in range(layers_per_pp_stage)
328
+ ]
329
+ else:
330
+ num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
331
+ src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
332
+ src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
333
+ layers_per_pp_stage = len(src_k_ptrs)
334
+ start_layer = self.pp_rank * layers_per_pp_stage
335
+ end_layer = start_layer + layers_per_pp_stage
336
+ dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
337
+ dst_v_ptrs = dst_kv_ptrs[
338
+ num_kv_layers + start_layer : num_kv_layers + end_layer
339
+ ]
340
+ kv_item_len = self.kv_args.kv_item_lens[0]
341
+
342
+ layers_params = [
343
+ (
344
+ src_k_ptrs[layer_id],
345
+ dst_k_ptrs[layer_id],
346
+ kv_item_len,
347
+ )
348
+ for layer_id in range(layers_per_pp_stage)
349
+ ] + [
350
+ (
351
+ src_v_ptrs[layer_id],
352
+ dst_v_ptrs[layer_id],
353
+ kv_item_len,
354
+ )
355
+ for layer_id in range(layers_per_pp_stage)
356
+ ]
357
+ assert layers_params is not None
308
358
 
309
359
  # Worker function for processing a single layer
310
360
  def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
@@ -343,7 +393,7 @@ class MooncakeKVManager(BaseKVManager):
343
393
  dst_kv_ptrs: list[int],
344
394
  dst_kv_indices: npt.NDArray[np.int64],
345
395
  dst_tp_rank: int,
346
- dst_tp_size: int,
396
+ dst_attn_tp_size: int,
347
397
  dst_kv_item_len: int,
348
398
  executor: concurrent.futures.ThreadPoolExecutor,
349
399
  ):
@@ -356,23 +406,22 @@ class MooncakeKVManager(BaseKVManager):
356
406
  This may introduce performance overhead (increased TTFT) for long sequences.
357
407
  """
358
408
  # Extract configuration
359
- local_tp_size = self.tp_size // self.dp_size
360
- local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
409
+ local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
361
410
  src_kv_item_len = self.kv_args.kv_item_lens[0]
362
- dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
411
+ dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size
363
412
  num_kv_heads = self.kv_args.kv_head_num
364
413
  num_layers = len(self.kv_args.kv_data_ptrs)
365
414
  page_size = self.kv_args.page_size
366
415
 
367
416
  # Calculate head distribution
368
417
  src_heads_per_rank = num_kv_heads
369
- dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
418
+ dst_heads_per_rank = num_kv_heads * self.attn_tp_size // dst_attn_tp_size
370
419
  bytes_per_head_slice_to_send = (
371
420
  dst_kv_item_len // page_size // dst_heads_per_rank
372
421
  )
373
422
 
374
423
  # Determine slicing parameters based on TP configuration
375
- if local_tp_size > dst_tp_size:
424
+ if self.attn_tp_size > dst_attn_tp_size:
376
425
  # Send KVCache from multiple prefill instances to 1 decode instance
377
426
  src_head_start_offset = 0
378
427
  num_heads_to_send = src_heads_per_rank
@@ -383,35 +432,55 @@ class MooncakeKVManager(BaseKVManager):
383
432
  num_heads_to_send = dst_heads_per_rank
384
433
  dst_head_start_offset = 0
385
434
 
386
- layers_params = []
387
- for layer_id in range(num_layers):
388
- # Calculate precise byte offset and length for the sub-slice within the token
389
- src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
390
- dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
391
- heads_bytes_per_token_to_send = (
392
- num_heads_to_send * bytes_per_head_slice_to_send
435
+ # pp is not supported on the decode side yet
436
+ num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
437
+ src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
438
+ src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
439
+ layers_per_pp_stage = len(src_k_ptrs)
440
+ start_layer = self.pp_rank * layers_per_pp_stage
441
+ end_layer = start_layer + layers_per_pp_stage
442
+ dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
443
+ dst_v_ptrs = dst_kv_ptrs[
444
+ num_kv_layers + start_layer : num_kv_layers + end_layer
445
+ ]
446
+
447
+ # Calculate precise byte offset and length for the sub-slice within the token
448
+ src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
449
+ dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
450
+ heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
451
+
452
+ # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
453
+ # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
454
+ if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
455
+ logger.error(
456
+ f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
457
+ f"target token slot size ({dst_kv_item_len // page_size})"
393
458
  )
459
+ return -1
394
460
 
395
- # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
396
- # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
397
- if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
398
- logger.error(
399
- f"[{mooncake_session_id}] Layer {layer_id}: "
400
- f"slice size ({heads_bytes_per_token_to_send}) exceeds "
401
- f"target token slot size ({dst_kv_item_len // page_size})"
402
- )
403
- return -1
404
- layers_params.append(
405
- (
406
- self.kv_args.kv_data_ptrs[layer_id],
407
- dst_kv_ptrs[layer_id],
408
- src_kv_item_len,
409
- dst_kv_item_len,
410
- src_head_slice_offset,
411
- dst_head_slice_offset,
412
- heads_bytes_per_token_to_send,
413
- )
461
+ layers_params = [
462
+ (
463
+ src_k_ptrs[layer_id],
464
+ dst_k_ptrs[layer_id],
465
+ src_kv_item_len,
466
+ dst_kv_item_len,
467
+ src_head_slice_offset,
468
+ dst_head_slice_offset,
469
+ heads_bytes_per_token_to_send,
414
470
  )
471
+ for layer_id in range(layers_per_pp_stage)
472
+ ] + [
473
+ (
474
+ src_v_ptrs[layer_id],
475
+ dst_v_ptrs[layer_id],
476
+ src_kv_item_len,
477
+ dst_kv_item_len,
478
+ src_head_slice_offset,
479
+ dst_head_slice_offset,
480
+ heads_bytes_per_token_to_send,
481
+ )
482
+ for layer_id in range(layers_per_pp_stage)
483
+ ]
415
484
 
416
485
  def process_layer_tp_aware(layer_params):
417
486
  (
@@ -562,9 +631,9 @@ class MooncakeKVManager(BaseKVManager):
562
631
  target_rank_registration_info: KVArgsRegisterInfo = (
563
632
  self.decode_kv_args_table[req.mooncake_session_id]
564
633
  )
565
- local_tp_size = self.tp_size // self.dp_size
566
634
  if self.is_mla_backend or (
567
- local_tp_size == target_rank_registration_info.dst_tp_size
635
+ self.attn_tp_size
636
+ == target_rank_registration_info.dst_attn_tp_size
568
637
  ):
569
638
  ret = self.send_kvcache(
570
639
  req.mooncake_session_id,
@@ -580,7 +649,7 @@ class MooncakeKVManager(BaseKVManager):
580
649
  target_rank_registration_info.dst_kv_ptrs,
581
650
  chunked_dst_kv_indice,
582
651
  target_rank_registration_info.dst_tp_rank,
583
- target_rank_registration_info.dst_tp_size,
652
+ target_rank_registration_info.dst_attn_tp_size,
584
653
  target_rank_registration_info.dst_kv_item_len,
585
654
  executor,
586
655
  )
@@ -863,11 +932,16 @@ class MooncakeKVManager(BaseKVManager):
863
932
  url = f"http://{bootstrap_server_url}/route"
864
933
  payload = {
865
934
  "role": "Prefill",
866
- "tp_size": self.tp_size,
867
- "dp_size": self.dp_size,
935
+ "attn_tp_size": self.attn_tp_size,
936
+ "attn_tp_rank": self.attn_tp_rank,
937
+ "attn_dp_size": self.attn_dp_size,
938
+ "attn_dp_rank": self.attn_dp_rank,
939
+ "pp_size": self.pp_size,
940
+ "pp_rank": self.pp_rank,
941
+ "system_dp_size": self.system_dp_size,
942
+ "system_dp_rank": self.system_dp_rank,
868
943
  "rank_ip": self.local_ip,
869
944
  "rank_port": self.rank_port,
870
- "engine_rank": self.kv_args.engine_rank,
871
945
  }
872
946
 
873
947
  try:
@@ -890,10 +964,12 @@ class MooncakeKVManager(BaseKVManager):
890
964
  ]
891
965
  for k in keys_to_remove:
892
966
  del self.connection_pool[k]
893
- if failed_bootstrap_addr in self.prefill_tp_size_table:
894
- del self.prefill_tp_size_table[failed_bootstrap_addr]
967
+ if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
968
+ del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
895
969
  if failed_bootstrap_addr in self.prefill_dp_size_table:
896
970
  del self.prefill_dp_size_table[failed_bootstrap_addr]
971
+ if failed_bootstrap_addr in self.prefill_pp_size_table:
972
+ del self.prefill_pp_size_table[failed_bootstrap_addr]
897
973
 
898
974
  possible_affected_rooms = self.addr_to_rooms_tracker.get(
899
975
  failed_bootstrap_addr, []
@@ -915,7 +991,7 @@ class MooncakeKVManager(BaseKVManager):
915
991
  self.update_status(room, KVPoll.Failed)
916
992
  affected_rooms.append(room)
917
993
  logger.error(
918
- f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), affected {len(affected_rooms)} requests"
994
+ f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
919
995
  )
920
996
 
921
997
 
@@ -1042,10 +1118,16 @@ class MooncakeKVReceiver(BaseKVReceiver):
1042
1118
  self.data_parallel_rank = data_parallel_rank
1043
1119
 
1044
1120
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
1045
- self.prefill_tp_size, self.prefill_dp_size = (
1046
- self._get_prefill_parallel_info_from_server()
1047
- )
1048
- if self.prefill_tp_size is None or self.prefill_dp_size is None:
1121
+ (
1122
+ self.prefill_attn_tp_size,
1123
+ self.prefill_dp_size,
1124
+ self.prefill_pp_size,
1125
+ ) = self._get_prefill_parallel_info_from_server()
1126
+ if (
1127
+ self.prefill_attn_tp_size is None
1128
+ or self.prefill_dp_size is None
1129
+ or self.prefill_pp_size is None
1130
+ ):
1049
1131
  self.kv_mgr.record_failure(
1050
1132
  self.bootstrap_room,
1051
1133
  f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
@@ -1054,43 +1136,47 @@ class MooncakeKVReceiver(BaseKVReceiver):
1054
1136
  return
1055
1137
  else:
1056
1138
  logger.debug(
1057
- f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_tp_size}"
1139
+ f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
1058
1140
  )
1059
- self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
1060
- self.prefill_tp_size
1141
+ self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
1142
+ self.prefill_attn_tp_size
1061
1143
  )
1062
1144
  self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
1063
1145
  self.prefill_dp_size
1064
1146
  )
1147
+ self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
1148
+ self.prefill_pp_size
1149
+ )
1065
1150
  else:
1066
- self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
1151
+ self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
1067
1152
  self.bootstrap_addr
1068
1153
  ]
1069
1154
  self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
1070
1155
  self.bootstrap_addr
1071
1156
  ]
1157
+ self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
1158
+ self.bootstrap_addr
1159
+ ]
1072
1160
 
1073
1161
  # Currently, we don't allow prefill instance and decode instance to
1074
1162
  # have different TP sizes per DP rank, except for models using MLA.
1075
- local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
1076
- prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
1077
- if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
1163
+ if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
1078
1164
  self.target_tp_rank = (
1079
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
1165
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1080
1166
  )
1081
1167
  self.required_dst_info_num = 1
1082
1168
  self.required_prefill_response_num = 1
1083
1169
  self.target_tp_ranks = [self.target_tp_rank]
1084
- elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
1170
+ elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
1085
1171
  if not self.kv_mgr.is_mla_backend:
1086
1172
  logger.warning_once(
1087
1173
  "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1088
1174
  )
1089
1175
  self.target_tp_rank = (
1090
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
1091
- ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
1176
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1177
+ ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
1092
1178
  self.required_dst_info_num = (
1093
- local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
1179
+ self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
1094
1180
  )
1095
1181
  self.required_prefill_response_num = 1
1096
1182
  self.target_tp_ranks = [self.target_tp_rank]
@@ -1103,10 +1189,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
1103
1189
  self.target_tp_ranks = [
1104
1190
  rank
1105
1191
  for rank in range(
1106
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
1107
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
1108
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
1109
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
1192
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
1193
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1194
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
1195
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1110
1196
  )
1111
1197
  ]
1112
1198
 
@@ -1116,7 +1202,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
1116
1202
  self.target_tp_rank = self.target_tp_ranks[0]
1117
1203
  self.required_dst_info_num = 1
1118
1204
  self.required_prefill_response_num = (
1119
- prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
1205
+ self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1120
1206
  )
1121
1207
 
1122
1208
  if self.data_parallel_rank is not None:
@@ -1136,31 +1222,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
1136
1222
  if bootstrap_key not in self.kv_mgr.connection_pool:
1137
1223
  bootstrap_infos = []
1138
1224
  for target_tp_rank in self.target_tp_ranks:
1139
- bootstrap_info = self._get_bootstrap_info_from_server(
1140
- target_tp_rank,
1141
- self.target_dp_group,
1142
- )
1143
- if bootstrap_info is not None:
1144
- if self.kv_mgr.is_mla_backend:
1145
- # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
1146
- bootstrap_info["is_dummy"] = not bool(
1147
- target_tp_rank == self.target_tp_rank
1148
- or self.target_tp_rank is None
1225
+ for target_pp_rank in range(self.prefill_pp_size):
1226
+ bootstrap_info = self._get_bootstrap_info_from_server(
1227
+ target_tp_rank, self.target_dp_group, target_pp_rank
1228
+ )
1229
+ if bootstrap_info is not None:
1230
+ if self.kv_mgr.is_mla_backend:
1231
+ # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
1232
+ bootstrap_info["is_dummy"] = not bool(
1233
+ target_tp_rank == self.target_tp_rank
1234
+ or self.target_tp_rank is None
1235
+ )
1236
+ else:
1237
+ # For non-MLA: all target_tp_ranks are selected real ranks
1238
+ bootstrap_info["is_dummy"] = False
1239
+ logger.debug(
1240
+ f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
1149
1241
  )
1242
+ bootstrap_infos.append(bootstrap_info)
1150
1243
  else:
1151
- # For non-MLA: all target_tp_ranks are selected real ranks
1152
- bootstrap_info["is_dummy"] = False
1153
- logger.debug(
1154
- f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
1155
- )
1156
- bootstrap_infos.append(bootstrap_info)
1157
- else:
1158
- self.kv_mgr.record_failure(
1159
- self.bootstrap_room,
1160
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
1161
- )
1162
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1163
- return
1244
+ self.kv_mgr.record_failure(
1245
+ self.bootstrap_room,
1246
+ f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
1247
+ )
1248
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1249
+ return
1164
1250
 
1165
1251
  self.bootstrap_infos = bootstrap_infos
1166
1252
  self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
@@ -1174,10 +1260,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
1174
1260
  self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
1175
1261
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
1176
1262
 
1177
- def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
1263
+ def _get_bootstrap_info_from_server(
1264
+ self, engine_rank, target_dp_group, target_pp_rank
1265
+ ):
1178
1266
  """Fetch the bootstrap info from the bootstrap server."""
1179
1267
  try:
1180
- url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
1268
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
1181
1269
  response = requests.get(url, timeout=5)
1182
1270
  if response.status_code == 200:
1183
1271
  bootstrap_info = response.json()
@@ -1191,24 +1279,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
1191
1279
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
1192
1280
  return None
1193
1281
 
1194
- def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]:
1282
+ def _get_prefill_parallel_info_from_server(
1283
+ self,
1284
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
1195
1285
  """Fetch the prefill parallel info from the bootstrap server."""
1196
1286
  try:
1197
- url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
1287
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
1198
1288
  response = requests.get(url)
1199
1289
  if response.status_code == 200:
1200
1290
  prefill_parallel_info = response.json()
1201
- return int(prefill_parallel_info["prefill_tp_size"]), int(
1202
- prefill_parallel_info["prefill_dp_size"]
1291
+ return (
1292
+ int(prefill_parallel_info["prefill_attn_tp_size"]),
1293
+ int(prefill_parallel_info["prefill_dp_size"]),
1294
+ int(prefill_parallel_info["prefill_pp_size"]),
1203
1295
  )
1204
1296
  else:
1205
1297
  logger.error(
1206
1298
  f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
1207
1299
  )
1208
- return None, None
1300
+ return None, None, None
1209
1301
  except Exception as e:
1210
1302
  logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
1211
- return None, None
1303
+ return None, None, None
1212
1304
 
1213
1305
  def _register_kv_args(self):
1214
1306
  for bootstrap_info in self.bootstrap_infos:
@@ -1218,11 +1310,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
1218
1310
  packed_aux_data_ptrs = b"".join(
1219
1311
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
1220
1312
  )
1313
+ # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
1221
1314
  tp_rank = self.kv_mgr.kv_args.engine_rank
1222
- tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
1223
1315
  kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
1224
1316
  dst_tp_rank = str(tp_rank).encode("ascii")
1225
- dst_tp_size = str(tp_size).encode("ascii")
1317
+ dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
1226
1318
  dst_kv_item_len = str(kv_item_len).encode("ascii")
1227
1319
 
1228
1320
  sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
@@ -1236,7 +1328,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
1236
1328
  packed_kv_data_ptrs,
1237
1329
  packed_aux_data_ptrs,
1238
1330
  dst_tp_rank,
1239
- dst_tp_size,
1331
+ dst_attn_tp_size,
1240
1332
  dst_kv_item_len,
1241
1333
  ]
1242
1334
  )
@@ -1347,10 +1439,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1347
1439
  self.store = dict()
1348
1440
  self.lock = asyncio.Lock()
1349
1441
  self._setup_routes()
1350
- self.tp_size = None
1442
+ self.pp_size = None
1443
+ self.attn_tp_size = None
1351
1444
  self.dp_size = None
1352
- self.tp_size_per_dp_rank = None
1353
- self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
1445
+ self.prefill_port_table: Dict[
1446
+ int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
1447
+ ] = {}
1354
1448
 
1355
1449
  # Start bootstrap server
1356
1450
  self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -1380,37 +1474,45 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1380
1474
  async def _handle_route_put(self, request: web.Request):
1381
1475
  data = await request.json()
1382
1476
  role = data["role"]
1383
- tp_size = data["tp_size"]
1384
- dp_size = data["dp_size"]
1477
+ attn_tp_size = data["attn_tp_size"]
1478
+ attn_tp_rank = data["attn_tp_rank"]
1479
+ attn_dp_size = data["attn_dp_size"]
1480
+ attn_dp_rank = data["attn_dp_rank"]
1481
+ pp_size = data["pp_size"]
1482
+ pp_rank = data["pp_rank"]
1483
+ system_dp_size = data["system_dp_size"]
1484
+ system_dp_rank = data["system_dp_rank"]
1385
1485
  rank_ip = data["rank_ip"]
1386
1486
  rank_port = int(data["rank_port"])
1387
- engine_rank = int(data["engine_rank"])
1388
1487
 
1389
- if self.tp_size is None:
1390
- self.tp_size = tp_size
1488
+ if self.attn_tp_size is None:
1489
+ self.attn_tp_size = attn_tp_size
1391
1490
 
1392
1491
  if self.dp_size is None:
1393
- self.dp_size = dp_size
1492
+ self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
1394
1493
 
1395
- tp_size_per_dp_rank = tp_size // dp_size
1396
- if self.tp_size_per_dp_rank is None:
1397
- self.tp_size_per_dp_rank = tp_size_per_dp_rank
1494
+ if self.pp_size is None:
1495
+ self.pp_size = pp_size
1398
1496
 
1399
1497
  if role == "Prefill":
1400
- dp_group = engine_rank // tp_size_per_dp_rank
1401
- tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
1498
+ if system_dp_size == 1:
1499
+ dp_group = attn_dp_rank
1500
+ else:
1501
+ dp_group = system_dp_rank
1402
1502
 
1403
1503
  # Add lock to make sure thread-safe
1404
1504
  async with self.lock:
1405
1505
  if dp_group not in self.prefill_port_table:
1406
1506
  self.prefill_port_table[dp_group] = {}
1507
+ if attn_tp_rank not in self.prefill_port_table[dp_group]:
1508
+ self.prefill_port_table[dp_group][attn_tp_rank] = {}
1407
1509
 
1408
- self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
1510
+ self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
1409
1511
  "rank_ip": rank_ip,
1410
1512
  "rank_port": rank_port,
1411
1513
  }
1412
1514
  logger.debug(
1413
- f"Register prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1515
+ f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1414
1516
  )
1415
1517
 
1416
1518
  return web.Response(text="OK", status=200)
@@ -1418,14 +1520,20 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1418
1520
  async def _handle_route_get(self, request: web.Request):
1419
1521
  engine_rank = request.query.get("engine_rank")
1420
1522
  target_dp_group = request.query.get("target_dp_group")
1421
- if not engine_rank or not target_dp_group:
1523
+ target_pp_rank = request.query.get("target_pp_rank")
1524
+ if not engine_rank or not target_dp_group or not target_pp_rank:
1422
1525
  return web.Response(text="Missing inputs for bootstrap server.", status=400)
1423
1526
 
1424
1527
  # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
1425
- if int(engine_rank) == -1 and int(target_dp_group) == -1:
1528
+ if (
1529
+ int(engine_rank) == -1
1530
+ and int(target_dp_group) == -1
1531
+ and int(target_pp_rank) == -1
1532
+ ):
1426
1533
  prefill_parallel_info = {
1427
- "prefill_tp_size": self.tp_size,
1534
+ "prefill_attn_tp_size": self.attn_tp_size,
1428
1535
  "prefill_dp_size": self.dp_size,
1536
+ "prefill_pp_size": self.pp_size,
1429
1537
  }
1430
1538
  return web.json_response(prefill_parallel_info, status=200)
1431
1539
 
@@ -1433,7 +1541,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1433
1541
  async with self.lock:
1434
1542
  bootstrap_info = self.prefill_port_table[int(target_dp_group)][
1435
1543
  int(engine_rank)
1436
- ]
1544
+ ][int(target_pp_rank)]
1437
1545
 
1438
1546
  if bootstrap_info is not None:
1439
1547
  return web.json_response(bootstrap_info, status=200)
@@ -103,6 +103,8 @@ class PrefillBootstrapQueue:
103
103
  kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
104
104
  kv_args = kv_args_class()
105
105
  kv_args.engine_rank = self.tp_rank
106
+ kv_args.pp_rank = self.pp_rank
107
+ kv_args.system_dp_rank = self.scheduler.dp_rank
106
108
  kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
107
109
  kv_args.prefill_pp_size = self.pp_size
108
110
  kv_data_ptrs, kv_data_lens, kv_item_lens = (