sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import (
22
22
  KVPoll,
23
23
  )
24
24
  from sglang.srt.disaggregation.utils import DisaggregationMode
25
+ from sglang.srt.distributed import get_pp_group
26
+ from sglang.srt.layers.dp_attention import (
27
+ get_attention_dp_rank,
28
+ get_attention_dp_size,
29
+ get_attention_tp_rank,
30
+ get_attention_tp_size,
31
+ )
25
32
  from sglang.srt.server_args import ServerArgs
26
33
  from sglang.srt.utils import (
27
34
  format_tcp_address,
28
35
  get_free_port,
29
- get_ip,
30
- get_local_ip_by_remote,
36
+ get_local_ip_auto,
31
37
  is_valid_ipv6_address,
32
38
  maybe_wrap_ipv6_address,
33
39
  )
@@ -47,31 +53,52 @@ class CommonKVManager(BaseKVManager):
47
53
  self.is_mla_backend = is_mla_backend
48
54
  self.disaggregation_mode = disaggregation_mode
49
55
  # for p/d multi node infer
56
+ self.bootstrap_host = server_args.host
50
57
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
51
58
  self.dist_init_addr = server_args.dist_init_addr
52
- self.tp_size = server_args.tp_size
53
- self.dp_size = server_args.dp_size
54
- self.enable_dp_attention = server_args.enable_dp_attention
55
- if not server_args.enable_dp_attention and server_args.dp_size != 1:
56
- raise ValueError(
57
- "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
58
- )
59
-
59
+ self.attn_tp_size = get_attention_tp_size()
60
+ self.attn_tp_rank = get_attention_tp_rank()
61
+ self.attn_dp_size = get_attention_dp_size()
62
+ self.attn_dp_rank = get_attention_dp_rank()
63
+ self.system_dp_size = (
64
+ 1 if server_args.enable_dp_attention else server_args.dp_size
65
+ )
66
+ self.system_dp_rank = (
67
+ self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
68
+ )
69
+ self.pp_size = server_args.pp_size
70
+ self.pp_rank = self.kv_args.pp_rank
60
71
  self.rank_port = get_free_port()
72
+ self.local_ip = get_local_ip_auto()
73
+ self.server_socket = zmq.Context().socket(zmq.PULL)
74
+ if is_valid_ipv6_address(self.local_ip):
75
+ self.server_socket.setsockopt(zmq.IPV6, 1)
76
+ self.request_status: Dict[int, KVPoll] = {}
77
+
61
78
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
62
79
  self._register_to_bootstrap()
80
+ self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
81
+ self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
82
+ self.pp_group = get_pp_group()
63
83
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
64
84
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
65
- self.prefill_tp_size_table: Dict[str, int] = {}
85
+ self.connection_lock = threading.Lock()
86
+ self.required_prefill_response_num_table: Dict[int, int] = {}
87
+ self.prefill_attn_tp_size_table: Dict[str, int] = {}
66
88
  self.prefill_dp_size_table: Dict[str, int] = {}
89
+ self.prefill_pp_size_table: Dict[str, int] = {}
67
90
  else:
68
91
  raise ValueError(
69
92
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
70
93
  )
71
94
 
95
+ def _bind_server_socket(self):
96
+ self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
97
+
72
98
  def _register_to_bootstrap(self):
73
99
  """Register KVSender to bootstrap server via HTTP POST."""
74
100
  if self.dist_init_addr:
101
+ # Multi-node case: bootstrap server's host is dist_init_addr
75
102
  if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
76
103
  if self.dist_init_addr.endswith("]"):
77
104
  host = self.dist_init_addr
@@ -80,30 +107,38 @@ class CommonKVManager(BaseKVManager):
80
107
  else:
81
108
  host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
82
109
  else:
83
- host = get_ip()
110
+ # Single-node case: bootstrap server's host is the same as http server's host
111
+ host = self.bootstrap_host
84
112
  host = maybe_wrap_ipv6_address(host)
85
113
 
86
114
  bootstrap_server_url = f"{host}:{self.bootstrap_port}"
87
115
  url = f"http://{bootstrap_server_url}/route"
88
116
  payload = {
89
117
  "role": "Prefill",
90
- "tp_size": self.tp_size,
91
- "dp_size": self.dp_size,
92
- "rank_ip": get_local_ip_by_remote(),
118
+ "attn_tp_size": self.attn_tp_size,
119
+ "attn_tp_rank": self.attn_tp_rank,
120
+ "attn_dp_size": self.attn_dp_size,
121
+ "attn_dp_rank": self.attn_dp_rank,
122
+ "pp_size": self.pp_size,
123
+ "pp_rank": self.pp_rank,
124
+ "system_dp_size": self.system_dp_size,
125
+ "system_dp_rank": self.system_dp_rank,
126
+ "rank_ip": self.local_ip,
93
127
  "rank_port": self.rank_port,
94
- "engine_rank": self.kv_args.engine_rank,
95
128
  }
96
129
 
97
130
  try:
98
- response = requests.put(url, json=payload)
131
+ response = requests.put(url, json=payload, timeout=5)
99
132
  if response.status_code == 200:
100
133
  logger.debug("Prefill successfully registered to bootstrap server.")
101
134
  else:
102
135
  logger.error(
103
- f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
136
+ f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
104
137
  )
105
138
  except Exception as e:
106
- logger.error(f"Prefill Failed to register to bootstrap server: {e}")
139
+ logger.error(
140
+ f"Prefill instance failed to register to bootstrap server: {e}"
141
+ )
107
142
 
108
143
  @cache
109
144
  def _connect(self, endpoint: str, is_ipv6: bool = False):
@@ -113,6 +148,68 @@ class CommonKVManager(BaseKVManager):
113
148
  socket.connect(endpoint)
114
149
  return socket
115
150
 
151
+ def get_mha_kv_ptrs_with_pp(
152
+ self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
153
+ ) -> Tuple[List[int], List[int], List[int], List[int], int]:
154
+ # pp is not supported on the decode side yet
155
+ start_layer = self.kv_args.prefill_start_layer
156
+ num_kv_layers = len(src_kv_ptrs) // 2
157
+ end_layer = start_layer + num_kv_layers
158
+ dst_num_total_layers = len(dst_kv_ptrs) // 2
159
+ src_k_ptrs = src_kv_ptrs[:num_kv_layers]
160
+ src_v_ptrs = src_kv_ptrs[num_kv_layers:]
161
+ dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
162
+ dst_v_ptrs = dst_kv_ptrs[
163
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
164
+ ]
165
+ layers_current_pp_stage = len(src_k_ptrs)
166
+ return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
167
+
168
+ def get_mla_kv_ptrs_with_pp(
169
+ self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
170
+ ) -> Tuple[List[int], List[int], int]:
171
+ # pp is not supported on the decode side yet
172
+ start_layer = self.kv_args.prefill_start_layer
173
+ end_layer = start_layer + len(src_kv_ptrs)
174
+ sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
175
+ layers_current_pp_stage = len(src_kv_ptrs)
176
+ return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
177
+
178
+
179
+ class CommonKVSender(BaseKVSender):
180
+
181
+ def __init__(
182
+ self,
183
+ mgr: BaseKVManager,
184
+ bootstrap_addr: str,
185
+ bootstrap_room: int,
186
+ dest_tp_ranks: List[int],
187
+ pp_rank: int,
188
+ ):
189
+ self.kv_mgr = mgr
190
+ self.bootstrap_room = bootstrap_room
191
+ self.aux_index = None
192
+ self.bootstrap_server_url = bootstrap_addr
193
+ # inner state
194
+ self.curr_idx = 0
195
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
196
+
197
+ def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
198
+ self.num_kv_indices = num_kv_indices
199
+ self.aux_index = aux_index
200
+
201
+ def send(
202
+ self,
203
+ kv_indices: npt.NDArray[np.int32],
204
+ ):
205
+ pass
206
+
207
+ def poll(self) -> KVPoll:
208
+ pass
209
+
210
+ def failure_exception(self):
211
+ raise Exception("Fake KVReceiver Exception")
212
+
116
213
 
117
214
  class CommonKVReceiver(BaseKVReceiver):
118
215
  _ctx = zmq.Context()
@@ -125,70 +222,93 @@ class CommonKVReceiver(BaseKVReceiver):
125
222
  mgr: BaseKVManager,
126
223
  bootstrap_addr: str,
127
224
  bootstrap_room: Optional[int] = None,
128
- data_parallel_rank: Optional[int] = None,
225
+ prefill_dp_rank: Optional[int] = None,
129
226
  ):
130
227
  self.bootstrap_room = bootstrap_room
131
228
  self.bootstrap_addr = bootstrap_addr
132
229
  self.kv_mgr = mgr
133
- self.data_parallel_rank = data_parallel_rank
230
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
134
231
 
135
232
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
136
- self.prefill_tp_size, self.prefill_dp_size = (
137
- self._get_prefill_dp_size_from_server()
138
- )
139
- if self.prefill_tp_size is None or self.prefill_dp_size is None:
140
- logger.error(
141
- f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
233
+ (
234
+ self.prefill_attn_tp_size,
235
+ self.prefill_dp_size,
236
+ self.prefill_pp_size,
237
+ ) = self._get_prefill_parallel_info_from_server()
238
+ if (
239
+ self.prefill_attn_tp_size is None
240
+ or self.prefill_dp_size is None
241
+ or self.prefill_pp_size is None
242
+ ):
243
+ self.kv_mgr.record_failure(
244
+ self.bootstrap_room,
245
+ f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
142
246
  )
247
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
248
+ return
143
249
  else:
144
- self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
145
- self.prefill_tp_size
250
+ logger.debug(
251
+ f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
252
+ )
253
+ self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
254
+ self.prefill_attn_tp_size
146
255
  )
147
256
  self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
148
257
  self.prefill_dp_size
149
258
  )
259
+ self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
260
+ self.prefill_pp_size
261
+ )
150
262
  else:
151
- self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
263
+ self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
152
264
  self.bootstrap_addr
153
265
  ]
154
266
  self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
155
267
  self.bootstrap_addr
156
268
  ]
269
+ self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
270
+ self.bootstrap_addr
271
+ ]
157
272
 
158
273
  # Currently, we don't allow prefill instance and decode instance to
159
274
  # have different TP sizes per DP rank, except for models using MLA.
160
- local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
161
- prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
162
- if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
275
+ if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
163
276
  self.target_tp_rank = (
164
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
277
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
165
278
  )
166
279
  self.required_dst_info_num = 1
280
+ self.required_prefill_response_num = 1 * (
281
+ self.prefill_pp_size // self.kv_mgr.pp_size
282
+ )
167
283
  self.target_tp_ranks = [self.target_tp_rank]
168
- elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
169
- assert (
170
- self.kv_mgr.is_mla_backend
171
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
284
+ elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
285
+ if not self.kv_mgr.is_mla_backend:
286
+ logger.warning_once(
287
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
288
+ )
172
289
  self.target_tp_rank = (
173
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
174
- ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
290
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
291
+ ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
175
292
  self.required_dst_info_num = (
176
- local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
293
+ self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
294
+ )
295
+ self.required_prefill_response_num = 1 * (
296
+ self.prefill_pp_size // self.kv_mgr.pp_size
177
297
  )
178
298
  self.target_tp_ranks = [self.target_tp_rank]
179
299
  else:
180
- assert (
181
- self.kv_mgr.is_mla_backend
182
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
183
-
300
+ if not self.kv_mgr.is_mla_backend:
301
+ logger.warning_once(
302
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
303
+ )
184
304
  # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
185
305
  self.target_tp_ranks = [
186
306
  rank
187
307
  for rank in range(
188
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
189
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
190
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
191
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
308
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
309
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
310
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
311
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
192
312
  )
193
313
  ]
194
314
 
@@ -197,13 +317,27 @@ class CommonKVReceiver(BaseKVReceiver):
197
317
  # or the KVPoll will never be set correctly
198
318
  self.target_tp_rank = self.target_tp_ranks[0]
199
319
  self.required_dst_info_num = 1
320
+ if self.kv_mgr.is_mla_backend:
321
+ self.required_prefill_response_num = (
322
+ self.prefill_pp_size // self.kv_mgr.pp_size
323
+ )
324
+ else:
325
+ self.required_prefill_response_num = (
326
+ self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
327
+ ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
200
328
 
201
- if self.data_parallel_rank is not None:
202
- logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
203
- self.target_dp_group = self.data_parallel_rank
329
+ if prefill_dp_rank is not None:
330
+ logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
331
+ self.prefill_dp_rank = prefill_dp_rank
204
332
  else:
205
- self.target_dp_group = bootstrap_room % self.prefill_dp_size
333
+ self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
334
+
335
+ # FIXME: alias here: target_dp_group -> prefill_dp_rank
336
+ self.target_dp_group = self.prefill_dp_rank
206
337
 
338
+ self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
339
+ self.required_prefill_response_num
340
+ )
207
341
  # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
208
342
  bootstrap_key = (
209
343
  f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -212,41 +346,49 @@ class CommonKVReceiver(BaseKVReceiver):
212
346
  if bootstrap_key not in self.kv_mgr.connection_pool:
213
347
  bootstrap_infos = []
214
348
  for target_tp_rank in self.target_tp_ranks:
215
- bootstrap_info = self._get_bootstrap_info_from_server(
216
- target_tp_rank,
217
- self.target_dp_group,
218
- )
219
- if bootstrap_info is not None:
220
- # NOTE: only support MLA for now: select one prefill rank as real rank
221
- bootstrap_info["is_dummy"] = not bool(
222
- target_tp_rank == self.target_tp_rank
223
- or self.target_tp_rank is None
224
- )
225
- bootstrap_infos.append(bootstrap_info)
226
- else:
227
- logger.error(
228
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
349
+ for target_pp_rank in range(self.prefill_pp_size):
350
+ bootstrap_info = self._get_bootstrap_info_from_server(
351
+ target_tp_rank, self.target_dp_group, target_pp_rank
229
352
  )
353
+ if bootstrap_info is not None:
354
+ if self.kv_mgr.is_mla_backend:
355
+ # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
356
+ bootstrap_info["is_dummy"] = not bool(
357
+ target_tp_rank == self.target_tp_rank
358
+ or self.target_tp_rank is None
359
+ )
360
+ else:
361
+ # For non-MLA: all target_tp_ranks are selected real ranks
362
+ bootstrap_info["is_dummy"] = False
363
+ logger.debug(
364
+ f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
365
+ )
366
+ bootstrap_infos.append(bootstrap_info)
367
+ else:
368
+ self.kv_mgr.record_failure(
369
+ self.bootstrap_room,
370
+ f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
371
+ )
372
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
373
+ return
374
+
230
375
  self.bootstrap_infos = bootstrap_infos
376
+ self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
231
377
 
232
- if len(self.bootstrap_infos) == 0:
233
- logger.error(
234
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
235
- )
236
- else:
237
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
238
- # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
239
- self._register_kv_args()
378
+ # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
379
+ self._register_kv_args()
240
380
  else:
241
381
  self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
242
382
 
243
383
  assert len(self.bootstrap_infos) > 0
244
384
 
245
- def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
385
+ def _get_bootstrap_info_from_server(
386
+ self, engine_rank, target_dp_group, target_pp_rank
387
+ ):
246
388
  """Fetch the bootstrap info from the bootstrap server."""
247
389
  try:
248
- url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
249
- response = requests.get(url)
390
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
391
+ response = requests.get(url, timeout=5)
250
392
  if response.status_code == 200:
251
393
  bootstrap_info = response.json()
252
394
  return bootstrap_info
@@ -259,24 +401,28 @@ class CommonKVReceiver(BaseKVReceiver):
259
401
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
260
402
  return None
261
403
 
262
- def _get_prefill_dp_size_from_server(self) -> int:
404
+ def _get_prefill_parallel_info_from_server(
405
+ self,
406
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
263
407
  """Fetch the prefill parallel info from the bootstrap server."""
264
408
  try:
265
- url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
409
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
266
410
  response = requests.get(url)
267
411
  if response.status_code == 200:
268
412
  prefill_parallel_info = response.json()
269
- return int(prefill_parallel_info["prefill_tp_size"]), int(
270
- prefill_parallel_info["prefill_dp_size"]
413
+ return (
414
+ int(prefill_parallel_info["prefill_attn_tp_size"]),
415
+ int(prefill_parallel_info["prefill_dp_size"]),
416
+ int(prefill_parallel_info["prefill_pp_size"]),
271
417
  )
272
418
  else:
273
419
  logger.error(
274
420
  f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
275
421
  )
276
- return None
422
+ return None, None, None
277
423
  except Exception as e:
278
424
  logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
279
- return None
425
+ return None, None, None
280
426
 
281
427
  @classmethod
282
428
  def _connect(cls, endpoint: str, is_ipv6: bool = False):
@@ -308,16 +454,19 @@ class CommonKVReceiver(BaseKVReceiver):
308
454
 
309
455
 
310
456
  class CommonKVBootstrapServer(BaseKVBootstrapServer):
311
- def __init__(self, port: int):
457
+ def __init__(self, host: str, port: int):
458
+ self.host = host
312
459
  self.port = port
313
460
  self.app = web.Application()
314
461
  self.store = dict()
315
462
  self.lock = asyncio.Lock()
316
463
  self._setup_routes()
317
- self.tp_size = None
464
+ self.pp_size = None
465
+ self.attn_tp_size = None
318
466
  self.dp_size = None
319
- self.tp_size_per_dp_rank = None
320
- self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
467
+ self.prefill_port_table: Dict[
468
+ int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
469
+ ] = {}
321
470
 
322
471
  # Start bootstrap server
323
472
  self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -328,6 +477,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
328
477
 
329
478
  def _setup_routes(self):
330
479
  self.app.router.add_route("*", "/route", self._handle_route)
480
+ self.app.router.add_get("/health", self._handle_health_check)
481
+
482
+ async def _handle_health_check(self, request):
483
+ return web.Response(text="OK", status=200)
331
484
 
332
485
  async def _handle_route(self, request: web.Request):
333
486
  method = request.method
@@ -343,37 +496,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
343
496
  async def _handle_route_put(self, request: web.Request):
344
497
  data = await request.json()
345
498
  role = data["role"]
346
- tp_size = data["tp_size"]
347
- dp_size = data["dp_size"]
499
+ attn_tp_size = data["attn_tp_size"]
500
+ attn_tp_rank = data["attn_tp_rank"]
501
+ attn_dp_size = data["attn_dp_size"]
502
+ attn_dp_rank = data["attn_dp_rank"]
503
+ pp_size = data["pp_size"]
504
+ pp_rank = data["pp_rank"]
505
+ system_dp_size = data["system_dp_size"]
506
+ system_dp_rank = data["system_dp_rank"]
348
507
  rank_ip = data["rank_ip"]
349
508
  rank_port = int(data["rank_port"])
350
- engine_rank = int(data["engine_rank"])
351
509
 
352
- if self.tp_size is None:
353
- self.tp_size = tp_size
510
+ if self.attn_tp_size is None:
511
+ self.attn_tp_size = attn_tp_size
354
512
 
355
513
  if self.dp_size is None:
356
- self.dp_size = dp_size
514
+ self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
357
515
 
358
- tp_size_per_dp_rank = tp_size // dp_size
359
- if self.tp_size_per_dp_rank == None:
360
- self.tp_size_per_dp_rank = tp_size_per_dp_rank
516
+ if self.pp_size is None:
517
+ self.pp_size = pp_size
361
518
 
362
- # Add lock to make sure thread-safe
363
519
  if role == "Prefill":
364
- dp_group = engine_rank // tp_size_per_dp_rank
365
- tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
520
+ if system_dp_size == 1:
521
+ dp_group = attn_dp_rank
522
+ else:
523
+ dp_group = system_dp_rank
366
524
 
525
+ # Add lock to make sure thread-safe
367
526
  async with self.lock:
368
527
  if dp_group not in self.prefill_port_table:
369
528
  self.prefill_port_table[dp_group] = {}
529
+ if attn_tp_rank not in self.prefill_port_table[dp_group]:
530
+ self.prefill_port_table[dp_group][attn_tp_rank] = {}
370
531
 
371
- self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
532
+ self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
372
533
  "rank_ip": rank_ip,
373
534
  "rank_port": rank_port,
374
535
  }
375
536
  logger.debug(
376
- f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
537
+ f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
377
538
  )
378
539
 
379
540
  return web.Response(text="OK", status=200)
@@ -381,14 +542,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
381
542
  async def _handle_route_get(self, request: web.Request):
382
543
  engine_rank = request.query.get("engine_rank")
383
544
  target_dp_group = request.query.get("target_dp_group")
384
- if not engine_rank or not target_dp_group:
545
+ target_pp_rank = request.query.get("target_pp_rank")
546
+ if not engine_rank or not target_dp_group or not target_pp_rank:
385
547
  return web.Response(text="Missing inputs for bootstrap server.", status=400)
386
548
 
387
549
  # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
388
- if int(engine_rank) == -1 and int(target_dp_group) == -1:
550
+ if (
551
+ int(engine_rank) == -1
552
+ and int(target_dp_group) == -1
553
+ and int(target_pp_rank) == -1
554
+ ):
389
555
  prefill_parallel_info = {
390
- "prefill_tp_size": self.tp_size,
556
+ "prefill_attn_tp_size": self.attn_tp_size,
391
557
  "prefill_dp_size": self.dp_size,
558
+ "prefill_pp_size": self.pp_size,
392
559
  }
393
560
  return web.json_response(prefill_parallel_info, status=200)
394
561
 
@@ -396,7 +563,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
396
563
  async with self.lock:
397
564
  bootstrap_info = self.prefill_port_table[int(target_dp_group)][
398
565
  int(engine_rank)
399
- ]
566
+ ][int(target_pp_rank)]
400
567
 
401
568
  if bootstrap_info is not None:
402
569
  return web.json_response(bootstrap_info, status=200)
@@ -409,10 +576,14 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
409
576
  self._loop = asyncio.new_event_loop()
410
577
  asyncio.set_event_loop(self._loop)
411
578
 
412
- self._runner = web.AppRunner(self.app)
579
+ access_log = None
580
+ if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
581
+ access_log = self.app.logger
582
+
583
+ self._runner = web.AppRunner(self.app, access_log=access_log)
413
584
  self._loop.run_until_complete(self._runner.setup())
414
585
 
415
- site = web.TCPSite(self._runner, port=self.port)
586
+ site = web.TCPSite(self._runner, host=self.host, port=self.port)
416
587
  self._loop.run_until_complete(site.start())
417
588
  self._loop.run_forever()
418
589
  except Exception as e: