sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import (
22
22
  KVPoll,
23
23
  )
24
24
  from sglang.srt.disaggregation.utils import DisaggregationMode
25
+ from sglang.srt.distributed import get_pp_group
26
+ from sglang.srt.layers.dp_attention import (
27
+ get_attention_dp_rank,
28
+ get_attention_dp_size,
29
+ get_attention_tp_rank,
30
+ get_attention_tp_size,
31
+ )
25
32
  from sglang.srt.server_args import ServerArgs
26
33
  from sglang.srt.utils import (
27
34
  format_tcp_address,
28
35
  get_free_port,
29
- get_ip,
30
- get_local_ip_by_remote,
36
+ get_local_ip_auto,
31
37
  is_valid_ipv6_address,
32
38
  maybe_wrap_ipv6_address,
33
39
  )
@@ -50,30 +56,49 @@ class CommonKVManager(BaseKVManager):
50
56
  self.bootstrap_host = server_args.host
51
57
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
52
58
  self.dist_init_addr = server_args.dist_init_addr
53
- self.tp_size = server_args.tp_size
54
- self.dp_size = server_args.dp_size
55
- self.enable_dp_attention = server_args.enable_dp_attention
56
- if not server_args.enable_dp_attention and server_args.dp_size != 1:
57
- raise ValueError(
58
- "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
59
- )
60
-
59
+ self.attn_tp_size = get_attention_tp_size()
60
+ self.attn_tp_rank = get_attention_tp_rank()
61
+ self.attn_dp_size = get_attention_dp_size()
62
+ self.attn_dp_rank = get_attention_dp_rank()
63
+ self.system_dp_size = (
64
+ 1 if server_args.enable_dp_attention else server_args.dp_size
65
+ )
66
+ self.system_dp_rank = (
67
+ self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
68
+ )
69
+ self.pp_size = server_args.pp_size
70
+ self.pp_rank = self.kv_args.pp_rank
61
71
  self.rank_port = get_free_port()
72
+ self.local_ip = get_local_ip_auto()
73
+ self.server_socket = zmq.Context().socket(zmq.PULL)
74
+ if is_valid_ipv6_address(self.local_ip):
75
+ self.server_socket.setsockopt(zmq.IPV6, 1)
76
+ self.request_status: Dict[int, KVPoll] = {}
77
+
62
78
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
63
79
  self._register_to_bootstrap()
80
+ self.transfer_infos = {}
81
+ self.decode_kv_args_table = {}
82
+ self.pp_group = get_pp_group()
64
83
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
65
84
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
66
- self.prefill_tp_size_table: Dict[str, int] = {}
85
+ self.connection_lock = threading.Lock()
86
+ self.required_prefill_response_num_table: Dict[int, int] = {}
87
+ self.prefill_attn_tp_size_table: Dict[str, int] = {}
67
88
  self.prefill_dp_size_table: Dict[str, int] = {}
89
+ self.prefill_pp_size_table: Dict[str, int] = {}
68
90
  else:
69
91
  raise ValueError(
70
92
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
71
93
  )
72
94
 
95
+ def _bind_server_socket(self):
96
+ self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
97
+
73
98
  def _register_to_bootstrap(self):
74
99
  """Register KVSender to bootstrap server via HTTP POST."""
75
100
  if self.dist_init_addr:
76
- # multi node: bootstrap server's host is dist_init_addr
101
+ # Multi-node case: bootstrap server's host is dist_init_addr
77
102
  if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
78
103
  if self.dist_init_addr.endswith("]"):
79
104
  host = self.dist_init_addr
@@ -82,7 +107,7 @@ class CommonKVManager(BaseKVManager):
82
107
  else:
83
108
  host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
84
109
  else:
85
- # single node: bootstrap server's host is same as http server's host
110
+ # Single-node case: bootstrap server's host is the same as http server's host
86
111
  host = self.bootstrap_host
87
112
  host = maybe_wrap_ipv6_address(host)
88
113
 
@@ -90,23 +115,30 @@ class CommonKVManager(BaseKVManager):
90
115
  url = f"http://{bootstrap_server_url}/route"
91
116
  payload = {
92
117
  "role": "Prefill",
93
- "tp_size": self.tp_size,
94
- "dp_size": self.dp_size,
95
- "rank_ip": get_local_ip_by_remote(),
118
+ "attn_tp_size": self.attn_tp_size,
119
+ "attn_tp_rank": self.attn_tp_rank,
120
+ "attn_dp_size": self.attn_dp_size,
121
+ "attn_dp_rank": self.attn_dp_rank,
122
+ "pp_size": self.pp_size,
123
+ "pp_rank": self.pp_rank,
124
+ "system_dp_size": self.system_dp_size,
125
+ "system_dp_rank": self.system_dp_rank,
126
+ "rank_ip": self.local_ip,
96
127
  "rank_port": self.rank_port,
97
- "engine_rank": self.kv_args.engine_rank,
98
128
  }
99
129
 
100
130
  try:
101
- response = requests.put(url, json=payload)
131
+ response = requests.put(url, json=payload, timeout=5)
102
132
  if response.status_code == 200:
103
133
  logger.debug("Prefill successfully registered to bootstrap server.")
104
134
  else:
105
135
  logger.error(
106
- f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
136
+ f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
107
137
  )
108
138
  except Exception as e:
109
- logger.error(f"Prefill Failed to register to bootstrap server: {e}")
139
+ logger.error(
140
+ f"Prefill instance failed to register to bootstrap server: {e}"
141
+ )
110
142
 
111
143
  @cache
112
144
  def _connect(self, endpoint: str, is_ipv6: bool = False):
@@ -116,6 +148,69 @@ class CommonKVManager(BaseKVManager):
116
148
  socket.connect(endpoint)
117
149
  return socket
118
150
 
151
+ def get_mha_kv_ptrs_with_pp(
152
+ self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
153
+ ) -> Tuple[List[int], List[int], List[int], List[int], int]:
154
+ # pp is not supported on the decode side yet
155
+ start_layer = self.kv_args.prefill_start_layer
156
+ num_kv_layers = len(src_kv_ptrs) // 2
157
+ end_layer = start_layer + num_kv_layers
158
+ dst_num_total_layers = len(dst_kv_ptrs) // 2
159
+ src_k_ptrs = src_kv_ptrs[:num_kv_layers]
160
+ src_v_ptrs = src_kv_ptrs[num_kv_layers:]
161
+ dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
162
+ dst_v_ptrs = dst_kv_ptrs[
163
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
164
+ ]
165
+ layers_current_pp_stage = len(src_k_ptrs)
166
+ return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
167
+
168
+ def get_mla_kv_ptrs_with_pp(
169
+ self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
170
+ ) -> Tuple[List[int], List[int], int]:
171
+ # pp is not supported on the decode side yet
172
+ start_layer = self.kv_args.prefill_start_layer
173
+ end_layer = start_layer + len(src_kv_ptrs)
174
+ sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
175
+ layers_current_pp_stage = len(src_kv_ptrs)
176
+ return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
177
+
178
+
179
+ class CommonKVSender(BaseKVSender):
180
+
181
+ def __init__(
182
+ self,
183
+ mgr: BaseKVManager,
184
+ bootstrap_addr: str,
185
+ bootstrap_room: int,
186
+ dest_tp_ranks: List[int],
187
+ pp_rank: int,
188
+ ):
189
+ self.kv_mgr = mgr
190
+ self.bootstrap_room = bootstrap_room
191
+ self.aux_index = None
192
+ self.bootstrap_server_url = bootstrap_addr
193
+ # inner state
194
+ self.curr_idx = 0
195
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
196
+
197
+ def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
198
+ self.num_kv_indices = num_kv_indices
199
+ self.aux_index = aux_index
200
+
201
+ def send(
202
+ self,
203
+ kv_indices: npt.NDArray[np.int32],
204
+ state_indices: Optional[List[int]] = None,
205
+ ):
206
+ pass
207
+
208
+ def poll(self) -> KVPoll:
209
+ pass
210
+
211
+ def failure_exception(self):
212
+ raise Exception("Fake KVReceiver Exception")
213
+
119
214
 
120
215
  class CommonKVReceiver(BaseKVReceiver):
121
216
  _ctx = zmq.Context()
@@ -133,61 +228,89 @@ class CommonKVReceiver(BaseKVReceiver):
133
228
  self.bootstrap_room = bootstrap_room
134
229
  self.bootstrap_addr = bootstrap_addr
135
230
  self.kv_mgr = mgr
231
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
136
232
 
137
233
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
138
- self.prefill_tp_size, self.prefill_dp_size = (
139
- self._get_prefill_dp_size_from_server()
140
- )
141
- if self.prefill_tp_size is None or self.prefill_dp_size is None:
142
- logger.error(
143
- f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
234
+ (
235
+ self.prefill_attn_tp_size,
236
+ self.prefill_dp_size,
237
+ self.prefill_pp_size,
238
+ ) = self._get_prefill_parallel_info_from_server()
239
+ if (
240
+ self.prefill_attn_tp_size is None
241
+ or self.prefill_dp_size is None
242
+ or self.prefill_pp_size is None
243
+ ):
244
+ self.kv_mgr.record_failure(
245
+ self.bootstrap_room,
246
+ f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
144
247
  )
248
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
249
+ self.bootstrap_infos = None
250
+ return
145
251
  else:
146
- self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
147
- self.prefill_tp_size
252
+ logger.debug(
253
+ f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
254
+ )
255
+ self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
256
+ self.prefill_attn_tp_size
148
257
  )
149
258
  self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
150
259
  self.prefill_dp_size
151
260
  )
261
+ self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
262
+ self.prefill_pp_size
263
+ )
152
264
  else:
153
- self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
265
+ self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
154
266
  self.bootstrap_addr
155
267
  ]
156
268
  self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
157
269
  self.bootstrap_addr
158
270
  ]
271
+ self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
272
+ self.bootstrap_addr
273
+ ]
159
274
 
160
275
  # Currently, we don't allow prefill instance and decode instance to
161
276
  # have different TP sizes per DP rank, except for models using MLA.
162
- local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
163
- prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
164
- if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
277
+ if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
165
278
  self.target_tp_rank = (
166
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
279
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
167
280
  )
168
281
  self.required_dst_info_num = 1
282
+ self.required_prefill_response_num = 1 * (
283
+ self.prefill_pp_size // self.kv_mgr.pp_size
284
+ )
169
285
  self.target_tp_ranks = [self.target_tp_rank]
170
- elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
286
+ elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
287
+ if not self.kv_mgr.is_mla_backend:
288
+ logger.warning_once(
289
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
290
+ )
171
291
  self.target_tp_rank = (
172
- self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
173
- ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
292
+ self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
293
+ ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
174
294
  self.required_dst_info_num = (
175
- local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
295
+ self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
296
+ )
297
+ self.required_prefill_response_num = 1 * (
298
+ self.prefill_pp_size // self.kv_mgr.pp_size
176
299
  )
177
300
  self.target_tp_ranks = [self.target_tp_rank]
178
301
  else:
179
- assert (
180
- self.kv_mgr.is_mla_backend
181
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
182
-
302
+ if not self.kv_mgr.is_mla_backend:
303
+ logger.warning_once(
304
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
305
+ )
183
306
  # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
184
307
  self.target_tp_ranks = [
185
308
  rank
186
309
  for rank in range(
187
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
188
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
189
- (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
190
- * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
310
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
311
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
312
+ (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
313
+ * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
191
314
  )
192
315
  ]
193
316
 
@@ -196,6 +319,14 @@ class CommonKVReceiver(BaseKVReceiver):
196
319
  # or the KVPoll will never be set correctly
197
320
  self.target_tp_rank = self.target_tp_ranks[0]
198
321
  self.required_dst_info_num = 1
322
+ if self.kv_mgr.is_mla_backend:
323
+ self.required_prefill_response_num = (
324
+ self.prefill_pp_size // self.kv_mgr.pp_size
325
+ )
326
+ else:
327
+ self.required_prefill_response_num = (
328
+ self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
329
+ ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
199
330
 
200
331
  if prefill_dp_rank is not None:
201
332
  logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
@@ -206,6 +337,9 @@ class CommonKVReceiver(BaseKVReceiver):
206
337
  # FIXME: alias here: target_dp_group -> prefill_dp_rank
207
338
  self.target_dp_group = self.prefill_dp_rank
208
339
 
340
+ self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
341
+ self.required_prefill_response_num
342
+ )
209
343
  # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
210
344
  bootstrap_key = (
211
345
  f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -214,41 +348,49 @@ class CommonKVReceiver(BaseKVReceiver):
214
348
  if bootstrap_key not in self.kv_mgr.connection_pool:
215
349
  bootstrap_infos = []
216
350
  for target_tp_rank in self.target_tp_ranks:
217
- bootstrap_info = self._get_bootstrap_info_from_server(
218
- target_tp_rank,
219
- self.target_dp_group,
220
- )
221
- if bootstrap_info is not None:
222
- # NOTE: only support MLA for now: select one prefill rank as real rank
223
- bootstrap_info["is_dummy"] = not bool(
224
- target_tp_rank == self.target_tp_rank
225
- or self.target_tp_rank is None
226
- )
227
- bootstrap_infos.append(bootstrap_info)
228
- else:
229
- logger.error(
230
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
351
+ for target_pp_rank in range(self.prefill_pp_size):
352
+ bootstrap_info = self._get_bootstrap_info_from_server(
353
+ target_tp_rank, self.target_dp_group, target_pp_rank
231
354
  )
355
+ if bootstrap_info is not None:
356
+ if self.kv_mgr.is_mla_backend:
357
+ # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
358
+ bootstrap_info["is_dummy"] = not bool(
359
+ target_tp_rank == self.target_tp_rank
360
+ or self.target_tp_rank is None
361
+ )
362
+ else:
363
+ # For non-MLA: all target_tp_ranks are selected real ranks
364
+ bootstrap_info["is_dummy"] = False
365
+ logger.debug(
366
+ f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
367
+ )
368
+ bootstrap_infos.append(bootstrap_info)
369
+ else:
370
+ self.kv_mgr.record_failure(
371
+ self.bootstrap_room,
372
+ f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
373
+ )
374
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
375
+ return
376
+
232
377
  self.bootstrap_infos = bootstrap_infos
378
+ self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
233
379
 
234
- if len(self.bootstrap_infos) == 0:
235
- logger.error(
236
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
237
- )
238
- else:
239
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
240
- # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
241
- self._register_kv_args()
380
+ # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
381
+ self._register_kv_args()
242
382
  else:
243
383
  self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
244
384
 
245
385
  assert len(self.bootstrap_infos) > 0
246
386
 
247
- def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
387
+ def _get_bootstrap_info_from_server(
388
+ self, engine_rank, target_dp_group, target_pp_rank
389
+ ):
248
390
  """Fetch the bootstrap info from the bootstrap server."""
249
391
  try:
250
- url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
251
- response = requests.get(url)
392
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
393
+ response = requests.get(url, timeout=5)
252
394
  if response.status_code == 200:
253
395
  bootstrap_info = response.json()
254
396
  return bootstrap_info
@@ -261,24 +403,28 @@ class CommonKVReceiver(BaseKVReceiver):
261
403
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
262
404
  return None
263
405
 
264
- def _get_prefill_dp_size_from_server(self) -> int:
406
+ def _get_prefill_parallel_info_from_server(
407
+ self,
408
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
265
409
  """Fetch the prefill parallel info from the bootstrap server."""
266
410
  try:
267
- url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
411
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
268
412
  response = requests.get(url)
269
413
  if response.status_code == 200:
270
414
  prefill_parallel_info = response.json()
271
- return int(prefill_parallel_info["prefill_tp_size"]), int(
272
- prefill_parallel_info["prefill_dp_size"]
415
+ return (
416
+ int(prefill_parallel_info["prefill_attn_tp_size"]),
417
+ int(prefill_parallel_info["prefill_dp_size"]),
418
+ int(prefill_parallel_info["prefill_pp_size"]),
273
419
  )
274
420
  else:
275
421
  logger.error(
276
422
  f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
277
423
  )
278
- return None
424
+ return None, None, None
279
425
  except Exception as e:
280
426
  logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
281
- return None
427
+ return None, None, None
282
428
 
283
429
  @classmethod
284
430
  def _connect(cls, endpoint: str, is_ipv6: bool = False):
@@ -317,10 +463,12 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
317
463
  self.store = dict()
318
464
  self.lock = asyncio.Lock()
319
465
  self._setup_routes()
320
- self.tp_size = None
466
+ self.pp_size = None
467
+ self.attn_tp_size = None
321
468
  self.dp_size = None
322
- self.tp_size_per_dp_rank = None
323
- self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
469
+ self.prefill_port_table: Dict[
470
+ int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
471
+ ] = {}
324
472
 
325
473
  # Start bootstrap server
326
474
  self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -331,6 +479,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
331
479
 
332
480
  def _setup_routes(self):
333
481
  self.app.router.add_route("*", "/route", self._handle_route)
482
+ self.app.router.add_get("/health", self._handle_health_check)
483
+
484
+ async def _handle_health_check(self, request):
485
+ return web.Response(text="OK", status=200)
334
486
 
335
487
  async def _handle_route(self, request: web.Request):
336
488
  method = request.method
@@ -346,37 +498,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
346
498
  async def _handle_route_put(self, request: web.Request):
347
499
  data = await request.json()
348
500
  role = data["role"]
349
- tp_size = data["tp_size"]
350
- dp_size = data["dp_size"]
501
+ attn_tp_size = data["attn_tp_size"]
502
+ attn_tp_rank = data["attn_tp_rank"]
503
+ attn_dp_size = data["attn_dp_size"]
504
+ attn_dp_rank = data["attn_dp_rank"]
505
+ pp_size = data["pp_size"]
506
+ pp_rank = data["pp_rank"]
507
+ system_dp_size = data["system_dp_size"]
508
+ system_dp_rank = data["system_dp_rank"]
351
509
  rank_ip = data["rank_ip"]
352
510
  rank_port = int(data["rank_port"])
353
- engine_rank = int(data["engine_rank"])
354
511
 
355
- if self.tp_size is None:
356
- self.tp_size = tp_size
512
+ if self.attn_tp_size is None:
513
+ self.attn_tp_size = attn_tp_size
357
514
 
358
515
  if self.dp_size is None:
359
- self.dp_size = dp_size
516
+ self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
360
517
 
361
- tp_size_per_dp_rank = tp_size // dp_size
362
- if self.tp_size_per_dp_rank == None:
363
- self.tp_size_per_dp_rank = tp_size_per_dp_rank
518
+ if self.pp_size is None:
519
+ self.pp_size = pp_size
364
520
 
365
- # Add lock to make sure thread-safe
366
521
  if role == "Prefill":
367
- dp_group = engine_rank // tp_size_per_dp_rank
368
- tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
522
+ if system_dp_size == 1:
523
+ dp_group = attn_dp_rank
524
+ else:
525
+ dp_group = system_dp_rank
369
526
 
527
+ # Add lock to make sure thread-safe
370
528
  async with self.lock:
371
529
  if dp_group not in self.prefill_port_table:
372
530
  self.prefill_port_table[dp_group] = {}
531
+ if attn_tp_rank not in self.prefill_port_table[dp_group]:
532
+ self.prefill_port_table[dp_group][attn_tp_rank] = {}
373
533
 
374
- self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
534
+ self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
375
535
  "rank_ip": rank_ip,
376
536
  "rank_port": rank_port,
377
537
  }
378
538
  logger.debug(
379
- f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
539
+ f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
380
540
  )
381
541
 
382
542
  return web.Response(text="OK", status=200)
@@ -384,14 +544,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
384
544
  async def _handle_route_get(self, request: web.Request):
385
545
  engine_rank = request.query.get("engine_rank")
386
546
  target_dp_group = request.query.get("target_dp_group")
387
- if not engine_rank or not target_dp_group:
547
+ target_pp_rank = request.query.get("target_pp_rank")
548
+ if not engine_rank or not target_dp_group or not target_pp_rank:
388
549
  return web.Response(text="Missing inputs for bootstrap server.", status=400)
389
550
 
390
551
  # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
391
- if int(engine_rank) == -1 and int(target_dp_group) == -1:
552
+ if (
553
+ int(engine_rank) == -1
554
+ and int(target_dp_group) == -1
555
+ and int(target_pp_rank) == -1
556
+ ):
392
557
  prefill_parallel_info = {
393
- "prefill_tp_size": self.tp_size,
558
+ "prefill_attn_tp_size": self.attn_tp_size,
394
559
  "prefill_dp_size": self.dp_size,
560
+ "prefill_pp_size": self.pp_size,
395
561
  }
396
562
  return web.json_response(prefill_parallel_info, status=200)
397
563
 
@@ -399,7 +565,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
399
565
  async with self.lock:
400
566
  bootstrap_info = self.prefill_port_table[int(target_dp_group)][
401
567
  int(engine_rank)
402
- ]
568
+ ][int(target_pp_rank)]
403
569
 
404
570
  if bootstrap_info is not None:
405
571
  return web.json_response(bootstrap_info, status=200)
@@ -412,7 +578,11 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
412
578
  self._loop = asyncio.new_event_loop()
413
579
  asyncio.set_event_loop(self._loop)
414
580
 
415
- self._runner = web.AppRunner(self.app)
581
+ access_log = None
582
+ if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
583
+ access_log = self.app.logger
584
+
585
+ self._runner = web.AppRunner(self.app, access_log=access_log)
416
586
  self._loop.run_until_complete(self._runner.setup())
417
587
 
418
588
  site = web.TCPSite(self._runner, host=self.host, port=self.port)