sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282):
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import (
22
22
  KVPoll,
23
23
  )
24
24
  from sglang.srt.disaggregation.utils import DisaggregationMode
25
+ from sglang.srt.distributed import get_pp_group
26
+ from sglang.srt.layers.dp_attention import (
27
+ get_attention_dp_rank,
28
+ get_attention_dp_size,
29
+ get_attention_tp_rank,
30
+ get_attention_tp_size,
31
+ )
25
32
  from sglang.srt.server_args import ServerArgs
26
33
  from sglang.srt.utils import (
27
34
  format_tcp_address,
28
35
  get_free_port,
29
- get_ip,
30
- get_local_ip_by_remote,
36
+ get_local_ip_auto,
31
37
  is_valid_ipv6_address,
32
38
  maybe_wrap_ipv6_address,
33
39
  )
@@ -50,30 +56,49 @@ class CommonKVManager(BaseKVManager):
50
56
  self.bootstrap_host = server_args.host
51
57
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
52
58
  self.dist_init_addr = server_args.dist_init_addr
53
- self.tp_size = server_args.tp_size
54
- self.dp_size = server_args.dp_size
55
- self.enable_dp_attention = server_args.enable_dp_attention
56
- if not server_args.enable_dp_attention and server_args.dp_size != 1:
57
- raise ValueError(
58
- "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
59
- )
60
-
59
+ self.attn_tp_size = get_attention_tp_size()
60
+ self.attn_tp_rank = get_attention_tp_rank()
61
+ self.attn_dp_size = get_attention_dp_size()
62
+ self.attn_dp_rank = get_attention_dp_rank()
63
+ self.system_dp_size = (
64
+ 1 if server_args.enable_dp_attention else server_args.dp_size
65
+ )
66
+ self.system_dp_rank = (
67
+ self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
68
+ )
69
+ self.pp_size = server_args.pp_size
70
+ self.pp_rank = self.kv_args.pp_rank
61
71
  self.rank_port = get_free_port()
72
+ self.local_ip = get_local_ip_auto()
73
+ self.server_socket = zmq.Context().socket(zmq.PULL)
74
+ if is_valid_ipv6_address(self.local_ip):
75
+ self.server_socket.setsockopt(zmq.IPV6, 1)
76
+ self.request_status: Dict[int, KVPoll] = {}
77
+
62
78
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
63
79
  self._register_to_bootstrap()
80
+ self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
81
+ self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
82
+ self.pp_group = get_pp_group()
64
83
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
65
84
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
66
- self.prefill_tp_size_table: Dict[str, int] = {}
85
+ self.connection_lock = threading.Lock()
86
+ self.required_prefill_response_num_table: Dict[int, int] = {}
87
+ self.prefill_attn_tp_size_table: Dict[str, int] = {}
67
88
  self.prefill_dp_size_table: Dict[str, int] = {}
89
+ self.prefill_pp_size_table: Dict[str, int] = {}
68
90
  else:
69
91
  raise ValueError(
70
92
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
71
93
  )
72
94
 
95
+ def _bind_server_socket(self):
96
+ self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
97
+
73
98
  def _register_to_bootstrap(self):
74
99
  """Register KVSender to bootstrap server via HTTP POST."""
75
100
  if self.dist_init_addr:
76
- # multi node: bootstrap server's host is dist_init_addr
101
+ # Multi-node case: bootstrap server's host is dist_init_addr
77
102
  if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
78
103
  if self.dist_init_addr.endswith("]"):
79
104
  host = self.dist_init_addr
@@ -82,7 +107,7 @@ class CommonKVManager(BaseKVManager):
82
107
  else:
83
108
  host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
84
109
  else:
85
- # single node: bootstrap server's host is same as http server's host
110
+ # Single-node case: bootstrap server's host is the same as http server's host
86
111
  host = self.bootstrap_host
87
112
  host = maybe_wrap_ipv6_address(host)
88
113
 
@@ -90,23 +115,30 @@ class CommonKVManager(BaseKVManager):
90
115
  url = f"http://{bootstrap_server_url}/route"
91
116
  payload = {
92
117
  "role": "Prefill",
93
- "tp_size": self.tp_size,
94
- "dp_size": self.dp_size,
95
- "rank_ip": get_local_ip_by_remote(),
118
+ "attn_tp_size": self.attn_tp_size,
119
+ "attn_tp_rank": self.attn_tp_rank,
120
+ "attn_dp_size": self.attn_dp_size,
121
+ "attn_dp_rank": self.attn_dp_rank,
122
+ "pp_size": self.pp_size,
123
+ "pp_rank": self.pp_rank,
124
+ "system_dp_size": self.system_dp_size,
125
+ "system_dp_rank": self.system_dp_rank,
126
+ "rank_ip": self.local_ip,
96
127
  "rank_port": self.rank_port,
97
- "engine_rank": self.kv_args.engine_rank,
98
128
  }
99
129
 
100
130
  try:
101
- response = requests.put(url, json=payload)
131
+ response = requests.put(url, json=payload, timeout=5)
102
132
  if response.status_code == 200:
103
133
  logger.debug("Prefill successfully registered to bootstrap server.")
104
134
  else:
105
135
  logger.error(
106
- f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
136
+ f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
107
137
  )
108
138
  except Exception as e:
109
- logger.error(f"Prefill Failed to register to bootstrap server: {e}")
139
+ logger.error(
140
+ f"Prefill instance failed to register to bootstrap server: {e}"
141
+ )
110
142
 
111
143
  @cache
112
144
  def _connect(self, endpoint: str, is_ipv6: bool = False):
@@ -116,6 +148,68 @@ class CommonKVManager(BaseKVManager):
116
148
  socket.connect(endpoint)
117
149
  return socket
118
150
 
151
+ def get_mha_kv_ptrs_with_pp(
152
+ self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
153
+ ) -> Tuple[List[int], List[int], List[int], List[int], int]:
154
+ # pp is not supported on the decode side yet
155
+ start_layer = self.kv_args.prefill_start_layer
156
+ num_kv_layers = len(src_kv_ptrs) // 2
157
+ end_layer = start_layer + num_kv_layers
158
+ dst_num_total_layers = len(dst_kv_ptrs) // 2
159
+ src_k_ptrs = src_kv_ptrs[:num_kv_layers]
160
+ src_v_ptrs = src_kv_ptrs[num_kv_layers:]
161
+ dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
162
+ dst_v_ptrs = dst_kv_ptrs[
163
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
164
+ ]
165
+ layers_current_pp_stage = len(src_k_ptrs)
166
+ return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
167
+
168
+ def get_mla_kv_ptrs_with_pp(
169
+ self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
170
+ ) -> Tuple[List[int], List[int], int]:
171
+ # pp is not supported on the decode side yet
172
+ start_layer = self.kv_args.prefill_start_layer
173
+ end_layer = start_layer + len(src_kv_ptrs)
174
+ sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
175
+ layers_current_pp_stage = len(src_kv_ptrs)
176
+ return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
177
+
178
+
179
+ class CommonKVSender(BaseKVSender):
180
+
181
+ def __init__(
182
+ self,
183
+ mgr: BaseKVManager,
184
+ bootstrap_addr: str,
185
+ bootstrap_room: int,
186
+ dest_tp_ranks: List[int],
187
+ pp_rank: int,
188
+ ):
189
+ self.kv_mgr = mgr
190
+ self.bootstrap_room = bootstrap_room
191
+ self.aux_index = None
192
+ self.bootstrap_server_url = bootstrap_addr
193
+ # inner state
194
+ self.curr_idx = 0
195
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
196
+
197
+ def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
198
+ self.num_kv_indices = num_kv_indices
199
+ self.aux_index = aux_index
200
+
201
+ def send(
202
+ self,
203
+ kv_indices: npt.NDArray[np.int32],
204
+ ):
205
+ pass
206
+
207
+ def poll(self) -> KVPoll:
208
+ pass
209
+
210
+ def failure_exception(self):
211
+ raise Exception("Fake KVReceiver Exception")
212
+
119
213
 
120
214
  class CommonKVReceiver(BaseKVReceiver):
121
215
  _ctx = zmq.Context()
@@ -133,61 +227,88 @@ class CommonKVReceiver(BaseKVReceiver):
133
227
  self.bootstrap_room = bootstrap_room
134
228
  self.bootstrap_addr = bootstrap_addr
135
229
  self.kv_mgr = mgr
230
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
136
231
 
137
232
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
138
- self.prefill_tp_size, self.prefill_dp_size = (
139
- self._get_prefill_dp_size_from_server()
140
- )
141
- if self.prefill_tp_size is None or self.prefill_dp_size is None:
142
- logger.error(
143
- f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
233
+ (
234
+ self.prefill_attn_tp_size,
235
+ self.prefill_dp_size,
236
+ self.prefill_pp_size,
237
+ ) = self._get_prefill_parallel_info_from_server()
238
+ if (
239
+ self.prefill_attn_tp_size is None
240
+ or self.prefill_dp_size is None
241
+ or self.prefill_pp_size is None
242
+ ):
243
+ self.kv_mgr.record_failure(
244
+ self.bootstrap_room,
245
+ f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
144
246
  )
247
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
248
+ return
145
249
  else:
146
- self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
147
- self.prefill_tp_size
250
+ logger.debug(
251
+ f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
252
+ )
253
+ self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
254
+ self.prefill_attn_tp_size
148
255
  )
149
256
  self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
150
257
  self.prefill_dp_size
151
258
  )
259
+ self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
260
+ self.prefill_pp_size
261
+ )
152
262
  else:
153
- self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
263
+ self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
154
264
  self.bootstrap_addr
155
265
  ]
156
266
  self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
157
267
  self.bootstrap_addr
158
268
  ]
269
+ self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
270
+ self.bootstrap_addr
271
+ ]
159
272
 
160
273
  # Currently, we don't allow prefill instance and decode instance to
161
274
  # have different TP sizes per DP rank, except for models using MLA.
162
- local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
163
- prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
164
- if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
275
+ if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
165
276
  self.target_tp_rank = (
166
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
277
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
167
278
  )
168
279
  self.required_dst_info_num = 1
280
+ self.required_prefill_response_num = 1 * (
281
+ self.prefill_pp_size // self.kv_mgr.pp_size
282
+ )
169
283
  self.target_tp_ranks = [self.target_tp_rank]
170
- elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
284
+ elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
285
+ if not self.kv_mgr.is_mla_backend:
286
+ logger.warning_once(
287
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
288
+ )
171
289
  self.target_tp_rank = (
172
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
173
- ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
290
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
291
+ ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
174
292
  self.required_dst_info_num = (
175
- local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
293
+ self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
294
+ )
295
+ self.required_prefill_response_num = 1 * (
296
+ self.prefill_pp_size // self.kv_mgr.pp_size
176
297
  )
177
298
  self.target_tp_ranks = [self.target_tp_rank]
178
299
  else:
179
- assert (
180
- self.kv_mgr.is_mla_backend
181
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
182
-
300
+ if not self.kv_mgr.is_mla_backend:
301
+ logger.warning_once(
302
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
303
+ )
183
304
  # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks;
184
305
  self.target_tp_ranks = [
185
306
  rank
186
307
  for rank in range(
187
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
188
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
189
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
190
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
308
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
309
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
310
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
311
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
191
312
  )
192
313
  ]
193
314
 
@@ -196,6 +317,14 @@ class CommonKVReceiver(BaseKVReceiver):
196
317
  # or the KVPoll will never be set correctly
197
318
  self.target_tp_rank = self.target_tp_ranks[0]
198
319
  self.required_dst_info_num = 1
320
+ if self.kv_mgr.is_mla_backend:
321
+ self.required_prefill_response_num = (
322
+ self.prefill_pp_size // self.kv_mgr.pp_size
323
+ )
324
+ else:
325
+ self.required_prefill_response_num = (
326
+ self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
327
+ ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
199
328
 
200
329
  if prefill_dp_rank is not None:
201
330
  logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
@@ -206,6 +335,9 @@ class CommonKVReceiver(BaseKVReceiver):
206
335
  # FIXME: alias here: target_dp_group -> prefill_dp_rank
207
336
  self.target_dp_group = self.prefill_dp_rank
208
337
 
338
+ self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
339
+ self.required_prefill_response_num
340
+ )
209
341
  # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
210
342
  bootstrap_key = (
211
343
  f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -214,41 +346,49 @@ class CommonKVReceiver(BaseKVReceiver):
214
346
  if bootstrap_key not in self.kv_mgr.connection_pool:
215
347
  bootstrap_infos = []
216
348
  for target_tp_rank in self.target_tp_ranks:
217
- bootstrap_info = self._get_bootstrap_info_from_server(
218
- target_tp_rank,
219
- self.target_dp_group,
220
- )
221
- if bootstrap_info is not None:
222
- # NOTE: only support MLA for now: select one prefill rank as real rank
223
- bootstrap_info["is_dummy"] = not bool(
224
- target_tp_rank == self.target_tp_rank
225
- or self.target_tp_rank is None
226
- )
227
- bootstrap_infos.append(bootstrap_info)
228
- else:
229
- logger.error(
230
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
349
+ for target_pp_rank in range(self.prefill_pp_size):
350
+ bootstrap_info = self._get_bootstrap_info_from_server(
351
+ target_tp_rank, self.target_dp_group, target_pp_rank
231
352
  )
353
+ if bootstrap_info is not None:
354
+ if self.kv_mgr.is_mla_backend:
355
+ # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
356
+ bootstrap_info["is_dummy"] = not bool(
357
+ target_tp_rank == self.target_tp_rank
358
+ or self.target_tp_rank is None
359
+ )
360
+ else:
361
+ # For non-MLA: all target_tp_ranks are selected real ranks
362
+ bootstrap_info["is_dummy"] = False
363
+ logger.debug(
364
+ f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
365
+ )
366
+ bootstrap_infos.append(bootstrap_info)
367
+ else:
368
+ self.kv_mgr.record_failure(
369
+ self.bootstrap_room,
370
+ f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
371
+ )
372
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
373
+ return
374
+
232
375
  self.bootstrap_infos = bootstrap_infos
376
+ self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
233
377
 
234
- if len(self.bootstrap_infos) == 0:
235
- logger.error(
236
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
237
- )
238
- else:
239
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
240
- # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
241
- self._register_kv_args()
378
+ # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
379
+ self._register_kv_args()
242
380
  else:
243
381
  self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
244
382
 
245
383
  assert len(self.bootstrap_infos) > 0
246
384
 
247
- def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
385
+ def _get_bootstrap_info_from_server(
386
+ self, engine_rank, target_dp_group, target_pp_rank
387
+ ):
248
388
  """Fetch the bootstrap info from the bootstrap server."""
249
389
  try:
250
- url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
251
- response = requests.get(url)
390
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
391
+ response = requests.get(url, timeout=5)
252
392
  if response.status_code == 200:
253
393
  bootstrap_info = response.json()
254
394
  return bootstrap_info
@@ -261,24 +401,28 @@ class CommonKVReceiver(BaseKVReceiver):
261
401
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
262
402
  return None
263
403
 
264
- def _get_prefill_dp_size_from_server(self) -> int:
404
+ def _get_prefill_parallel_info_from_server(
405
+ self,
406
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
265
407
  """Fetch the prefill parallel info from the bootstrap server."""
266
408
  try:
267
- url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
409
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
268
410
  response = requests.get(url)
269
411
  if response.status_code == 200:
270
412
  prefill_parallel_info = response.json()
271
- return int(prefill_parallel_info["prefill_tp_size"]), int(
272
- prefill_parallel_info["prefill_dp_size"]
413
+ return (
414
+ int(prefill_parallel_info["prefill_attn_tp_size"]),
415
+ int(prefill_parallel_info["prefill_dp_size"]),
416
+ int(prefill_parallel_info["prefill_pp_size"]),
273
417
  )
274
418
  else:
275
419
  logger.error(
276
420
  f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
277
421
  )
278
- return None
422
+ return None, None, None
279
423
  except Exception as e:
280
424
  logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
281
- return None
425
+ return None, None, None
282
426
 
283
427
  @classmethod
284
428
  def _connect(cls, endpoint: str, is_ipv6: bool = False):
@@ -317,10 +461,12 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
317
461
  self.store = dict()
318
462
  self.lock = asyncio.Lock()
319
463
  self._setup_routes()
320
- self.tp_size = None
464
+ self.pp_size = None
465
+ self.attn_tp_size = None
321
466
  self.dp_size = None
322
- self.tp_size_per_dp_rank = None
323
- self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
467
+ self.prefill_port_table: Dict[
468
+ int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
469
+ ] = {}
324
470
 
325
471
  # Start bootstrap server
326
472
  self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -331,6 +477,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
331
477
 
332
478
  def _setup_routes(self):
333
479
  self.app.router.add_route("*", "/route", self._handle_route)
480
+ self.app.router.add_get("/health", self._handle_health_check)
481
+
482
+ async def _handle_health_check(self, request):
483
+ return web.Response(text="OK", status=200)
334
484
 
335
485
  async def _handle_route(self, request: web.Request):
336
486
  method = request.method
@@ -346,37 +496,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
346
496
  async def _handle_route_put(self, request: web.Request):
347
497
  data = await request.json()
348
498
  role = data["role"]
349
- tp_size = data["tp_size"]
350
- dp_size = data["dp_size"]
499
+ attn_tp_size = data["attn_tp_size"]
500
+ attn_tp_rank = data["attn_tp_rank"]
501
+ attn_dp_size = data["attn_dp_size"]
502
+ attn_dp_rank = data["attn_dp_rank"]
503
+ pp_size = data["pp_size"]
504
+ pp_rank = data["pp_rank"]
505
+ system_dp_size = data["system_dp_size"]
506
+ system_dp_rank = data["system_dp_rank"]
351
507
  rank_ip = data["rank_ip"]
352
508
  rank_port = int(data["rank_port"])
353
- engine_rank = int(data["engine_rank"])
354
509
 
355
- if self.tp_size is None:
356
- self.tp_size = tp_size
510
+ if self.attn_tp_size is None:
511
+ self.attn_tp_size = attn_tp_size
357
512
 
358
513
  if self.dp_size is None:
359
- self.dp_size = dp_size
514
+ self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
360
515
 
361
- tp_size_per_dp_rank = tp_size // dp_size
362
- if self.tp_size_per_dp_rank == None:
363
- self.tp_size_per_dp_rank = tp_size_per_dp_rank
516
+ if self.pp_size is None:
517
+ self.pp_size = pp_size
364
518
 
365
- # Add lock to make sure thread-safe
366
519
  if role == "Prefill":
367
- dp_group = engine_rank // tp_size_per_dp_rank
368
- tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
520
+ if system_dp_size == 1:
521
+ dp_group = attn_dp_rank
522
+ else:
523
+ dp_group = system_dp_rank
369
524
 
525
+ # Add lock to make sure thread-safe
370
526
  async with self.lock:
371
527
  if dp_group not in self.prefill_port_table:
372
528
  self.prefill_port_table[dp_group] = {}
529
+ if attn_tp_rank not in self.prefill_port_table[dp_group]:
530
+ self.prefill_port_table[dp_group][attn_tp_rank] = {}
373
531
 
374
- self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
532
+ self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
375
533
  "rank_ip": rank_ip,
376
534
  "rank_port": rank_port,
377
535
  }
378
536
  logger.debug(
379
- f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
537
+ f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
380
538
  )
381
539
 
382
540
  return web.Response(text="OK", status=200)
@@ -384,14 +542,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
384
542
  async def _handle_route_get(self, request: web.Request):
385
543
  engine_rank = request.query.get("engine_rank")
386
544
  target_dp_group = request.query.get("target_dp_group")
387
- if not engine_rank or not target_dp_group:
545
+ target_pp_rank = request.query.get("target_pp_rank")
546
+ if not engine_rank or not target_dp_group or not target_pp_rank:
388
547
  return web.Response(text="Missing inputs for bootstrap server.", status=400)
389
548
 
390
549
  # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
391
- if int(engine_rank) == -1 and int(target_dp_group) == -1:
550
+ if (
551
+ int(engine_rank) == -1
552
+ and int(target_dp_group) == -1
553
+ and int(target_pp_rank) == -1
554
+ ):
392
555
  prefill_parallel_info = {
393
- "prefill_tp_size": self.tp_size,
556
+ "prefill_attn_tp_size": self.attn_tp_size,
394
557
  "prefill_dp_size": self.dp_size,
558
+ "prefill_pp_size": self.pp_size,
395
559
  }
396
560
  return web.json_response(prefill_parallel_info, status=200)
397
561
 
@@ -399,7 +563,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
399
563
  async with self.lock:
400
564
  bootstrap_info = self.prefill_port_table[int(target_dp_group)][
401
565
  int(engine_rank)
402
- ]
566
+ ][int(target_pp_rank)]
403
567
 
404
568
  if bootstrap_info is not None:
405
569
  return web.json_response(bootstrap_info, status=200)
@@ -412,7 +576,11 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
412
576
  self._loop = asyncio.new_event_loop()
413
577
  asyncio.set_event_loop(self._loop)
414
578
 
415
- self._runner = web.AppRunner(self.app)
579
+ access_log = None
580
+ if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
581
+ access_log = self.app.logger
582
+
583
+ self._runner = web.AppRunner(self.app, access_log=access_log)
416
584
  self._loop.run_until_complete(self._runner.setup())
417
585
 
418
586
  site = web.TCPSite(self._runner, host=self.host, port=self.port)