sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import concurrent.futures
5
4
  import ctypes
6
5
  import dataclasses
7
6
  import logging
8
7
  import os
9
- import queue
10
- import socket
11
8
  import struct
12
9
  import threading
13
10
  import time
14
11
  from collections import defaultdict
15
- from functools import cache
16
- from typing import Dict, List, Optional, Tuple, Union
12
+ from typing import Dict, List, Optional, Tuple
17
13
 
18
14
  import numpy as np
19
15
  import numpy.typing as npt
20
16
  import requests
21
17
  import zmq
22
- from aiohttp import web
23
-
24
- from sglang.srt.disaggregation.base.conn import (
25
- BaseKVBootstrapServer,
26
- BaseKVManager,
27
- BaseKVReceiver,
28
- BaseKVSender,
29
- KVArgs,
30
- KVPoll,
18
+
19
+ from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
20
+ from sglang.srt.disaggregation.common.conn import (
21
+ CommonKVBootstrapServer,
22
+ CommonKVManager,
23
+ CommonKVReceiver,
24
+ CommonKVSender,
31
25
  )
32
26
  from sglang.srt.disaggregation.common.utils import (
33
27
  FastQueue,
@@ -35,23 +29,12 @@ from sglang.srt.disaggregation.common.utils import (
35
29
  )
36
30
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
37
31
  from sglang.srt.disaggregation.utils import DisaggregationMode
38
- from sglang.srt.distributed import get_pp_group
39
- from sglang.srt.layers.dp_attention import (
40
- get_attention_dp_rank,
41
- get_attention_dp_size,
42
- get_attention_tp_rank,
43
- get_attention_tp_size,
44
- )
45
32
  from sglang.srt.server_args import ServerArgs
46
33
  from sglang.srt.utils import (
47
34
  format_tcp_address,
48
35
  get_bool_env_var,
49
- get_free_port,
50
36
  get_int_env_var,
51
- get_ip,
52
- get_local_ip_auto,
53
37
  is_valid_ipv6_address,
54
- maybe_wrap_ipv6_address,
55
38
  )
56
39
 
57
40
  logger = logging.getLogger(__name__)
@@ -159,7 +142,7 @@ class AuxDataCodec:
159
142
  return
160
143
 
161
144
 
162
- class MooncakeKVManager(BaseKVManager):
145
+ class MooncakeKVManager(CommonKVManager):
163
146
  AUX_DATA_HEADER = b"AUX_DATA"
164
147
 
165
148
  def __init__(
@@ -169,42 +152,14 @@ class MooncakeKVManager(BaseKVManager):
169
152
  server_args: ServerArgs,
170
153
  is_mla_backend: Optional[bool] = False,
171
154
  ):
172
- self.kv_args = args
173
- self.local_ip = get_local_ip_auto()
174
- self.is_mla_backend = is_mla_backend
175
- self.disaggregation_mode = disaggregation_mode
155
+ super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
176
156
  self.init_engine()
177
- # for p/d multi node infer
178
- self.bootstrap_port = server_args.disaggregation_bootstrap_port
179
- self.dist_init_addr = server_args.dist_init_addr
180
- self.attn_tp_size = get_attention_tp_size()
181
- self.attn_tp_rank = get_attention_tp_rank()
182
- self.attn_dp_size = get_attention_dp_size()
183
- self.attn_dp_rank = get_attention_dp_rank()
184
- self.system_dp_size = (
185
- 1 if server_args.enable_dp_attention else server_args.dp_size
186
- )
187
- self.system_dp_rank = (
188
- self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
189
- )
190
- self.pp_size = server_args.pp_size
191
- self.pp_rank = self.kv_args.pp_rank
192
- self.request_status: Dict[int, KVPoll] = {}
193
- self.rank_port = None
194
- self.server_socket = zmq.Context().socket(zmq.PULL)
195
- if is_valid_ipv6_address(self.local_ip):
196
- self.server_socket.setsockopt(zmq.IPV6, 1)
197
-
198
157
  self.register_buffer_to_engine()
199
158
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
200
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
201
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
202
159
  self.start_prefill_thread()
203
- self._register_to_bootstrap()
204
160
  self.session_failures = defaultdict(int)
205
161
  self.failed_sessions = set()
206
162
  self.session_lock = threading.Lock()
207
- self.pp_group = get_pp_group()
208
163
  # Determine the number of threads to use for kv sender
209
164
  cpu_count = os.cpu_count()
210
165
  transfer_thread_pool_size = get_int_env_var(
@@ -244,8 +199,6 @@ class MooncakeKVManager(BaseKVManager):
244
199
  self.session_pool = defaultdict(requests.Session)
245
200
  self.session_pool_lock = threading.Lock()
246
201
  self.addr_to_rooms_tracker = defaultdict(set)
247
- self.connection_lock = threading.Lock()
248
- self.required_prefill_response_num_table: Dict[int, int] = {}
249
202
  self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
250
203
  # Heartbeat interval should be at least 2 seconds
251
204
  self.heartbeat_interval = max(
@@ -256,20 +209,12 @@ class MooncakeKVManager(BaseKVManager):
256
209
  get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
257
210
  )
258
211
  self.start_decode_thread()
259
- self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
260
- self.prefill_attn_tp_size_table: Dict[str, int] = {}
261
- self.prefill_dp_size_table: Dict[str, int] = {}
262
- self.prefill_pp_size_table: Dict[str, int] = {}
263
212
  # If a timeout happens on the decode side, it means decode instances
264
213
  # fail to receive the KV Cache transfer done signal after bootstrapping.
265
214
  # These timeout requests should be aborted to release the tree cache.
266
215
  self.waiting_timeout = get_int_env_var(
267
216
  "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
268
217
  )
269
- else:
270
- raise ValueError(
271
- f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
272
- )
273
218
 
274
219
  self.failure_records: Dict[int, str] = {}
275
220
  self.failure_lock = threading.Lock()
@@ -294,14 +239,6 @@ class MooncakeKVManager(BaseKVManager):
294
239
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
295
240
  )
296
241
 
297
- @cache
298
- def _connect(self, endpoint: str, is_ipv6: bool = False):
299
- socket = zmq.Context().socket(zmq.PUSH)
300
- if is_ipv6:
301
- socket.setsockopt(zmq.IPV6, 1)
302
- socket.connect(endpoint)
303
- return socket
304
-
305
242
  def _transfer_data(self, mooncake_session_id, transfer_blocks):
306
243
  if not transfer_blocks:
307
244
  return 0
@@ -327,12 +264,10 @@ class MooncakeKVManager(BaseKVManager):
327
264
  layers_params = None
328
265
 
329
266
  # pp is not supported on the decode side yet
330
- start_layer = self.kv_args.prefill_start_layer
331
- end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
332
267
  if self.is_mla_backend:
333
- src_kv_ptrs = self.kv_args.kv_data_ptrs
334
- layers_per_pp_stage = len(src_kv_ptrs)
335
- dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
268
+ src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
269
+ self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
270
+ )
336
271
  kv_item_len = self.kv_args.kv_item_lens[0]
337
272
  layers_params = [
338
273
  (
@@ -340,18 +275,12 @@ class MooncakeKVManager(BaseKVManager):
340
275
  dst_kv_ptrs[layer_id],
341
276
  kv_item_len,
342
277
  )
343
- for layer_id in range(layers_per_pp_stage)
278
+ for layer_id in range(layers_current_pp_stage)
344
279
  ]
345
280
  else:
346
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
347
- dst_num_total_layers = num_kv_layers * self.pp_size
348
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
349
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
350
- layers_per_pp_stage = len(src_k_ptrs)
351
- dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
352
- dst_v_ptrs = dst_kv_ptrs[
353
- dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
354
- ]
281
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
282
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
283
+ )
355
284
  kv_item_len = self.kv_args.kv_item_lens[0]
356
285
  layers_params = [
357
286
  (
@@ -359,14 +288,14 @@ class MooncakeKVManager(BaseKVManager):
359
288
  dst_k_ptrs[layer_id],
360
289
  kv_item_len,
361
290
  )
362
- for layer_id in range(layers_per_pp_stage)
291
+ for layer_id in range(layers_current_pp_stage)
363
292
  ] + [
364
293
  (
365
294
  src_v_ptrs[layer_id],
366
295
  dst_v_ptrs[layer_id],
367
296
  kv_item_len,
368
297
  )
369
- for layer_id in range(layers_per_pp_stage)
298
+ for layer_id in range(layers_current_pp_stage)
370
299
  ]
371
300
  assert layers_params is not None
372
301
 
@@ -458,22 +387,15 @@ class MooncakeKVManager(BaseKVManager):
458
387
  dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
459
388
  else:
460
389
  # Send KVCache from 1 prefill instance to multiple decode instances
461
- src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
390
+ src_head_start_offset = (
391
+ dst_tp_rank_in_group * dst_heads_per_rank
392
+ ) % src_heads_per_rank
462
393
  num_heads_to_send = dst_heads_per_rank
463
394
  dst_head_start_offset = 0
464
395
 
465
- # pp is not supported on the decode side yet
466
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
467
- dst_num_total_layers = num_kv_layers * self.pp_size
468
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
469
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
470
- layers_per_pp_stage = len(src_k_ptrs)
471
- start_layer = self.pp_rank * layers_per_pp_stage
472
- end_layer = start_layer + layers_per_pp_stage
473
- dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
474
- dst_v_ptrs = dst_kv_ptrs[
475
- dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
476
- ]
396
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
397
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
398
+ )
477
399
 
478
400
  # Calculate precise byte offset and length for the sub-slice within the token
479
401
  src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
@@ -499,7 +421,7 @@ class MooncakeKVManager(BaseKVManager):
499
421
  dst_head_slice_offset,
500
422
  heads_bytes_per_token_to_send,
501
423
  )
502
- for layer_id in range(layers_per_pp_stage)
424
+ for layer_id in range(layers_current_pp_stage)
503
425
  ] + [
504
426
  (
505
427
  src_v_ptrs[layer_id],
@@ -510,7 +432,7 @@ class MooncakeKVManager(BaseKVManager):
510
432
  dst_head_slice_offset,
511
433
  heads_bytes_per_token_to_send,
512
434
  )
513
- for layer_id in range(layers_per_pp_stage)
435
+ for layer_id in range(layers_current_pp_stage)
514
436
  ]
515
437
 
516
438
  def process_layer_tp_aware(layer_params):
@@ -651,6 +573,26 @@ class MooncakeKVManager(BaseKVManager):
651
573
  ]
652
574
  )
653
575
 
576
+ def _handle_aux_data(self, msg: List[bytes]):
577
+ """Handle AUX_DATA messages received by the decode thread."""
578
+ room = int(msg[1].decode("ascii"))
579
+ buffer_index = int(msg[2].decode("ascii"))
580
+ aux_index = int(msg[3].decode("ascii"))
581
+ data_length = struct.unpack(">I", msg[4])[0]
582
+ data = msg[5]
583
+
584
+ if len(data) != data_length:
585
+ logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
586
+ return
587
+
588
+ AuxDataCodec.deserialize_data_to_buffer(
589
+ self.kv_args, buffer_index, aux_index, data
590
+ )
591
+
592
+ logger.debug(
593
+ f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
594
+ )
595
+
654
596
  def sync_status_to_decode_endpoint(
655
597
  self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
656
598
  ):
@@ -799,11 +741,7 @@ class MooncakeKVManager(BaseKVManager):
799
741
  f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
800
742
  )
801
743
 
802
- def _bind_server_socket(self):
803
- self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
804
-
805
744
  def start_prefill_thread(self):
806
- self.rank_port = get_free_port()
807
745
  self._bind_server_socket()
808
746
 
809
747
  def bootstrap_thread():
@@ -841,28 +779,7 @@ class MooncakeKVManager(BaseKVManager):
841
779
 
842
780
  threading.Thread(target=bootstrap_thread).start()
843
781
 
844
- def _handle_aux_data(self, msg: List[bytes]):
845
- """Handle AUX_DATA messages received by the decode thread."""
846
- room = int(msg[1].decode("ascii"))
847
- buffer_index = int(msg[2].decode("ascii"))
848
- aux_index = int(msg[3].decode("ascii"))
849
- data_length = struct.unpack(">I", msg[4])[0]
850
- data = msg[5]
851
-
852
- if len(data) != data_length:
853
- logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
854
- return
855
-
856
- AuxDataCodec.deserialize_data_to_buffer(
857
- self.kv_args, buffer_index, aux_index, data
858
- )
859
-
860
- logger.debug(
861
- f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
862
- )
863
-
864
782
  def start_decode_thread(self):
865
- self.rank_port = get_free_port()
866
783
  self._bind_server_socket()
867
784
 
868
785
  def decode_thread():
@@ -1017,49 +934,6 @@ class MooncakeKVManager(BaseKVManager):
1017
934
  def get_session_id(self):
1018
935
  return self.engine.get_session_id()
1019
936
 
1020
- def _register_to_bootstrap(self):
1021
- """Register KVSender to bootstrap server via HTTP POST."""
1022
- if self.dist_init_addr:
1023
- if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
1024
- if self.dist_init_addr.endswith("]"):
1025
- host = self.dist_init_addr
1026
- else:
1027
- host, _ = self.dist_init_addr.rsplit(":", 1)
1028
- else:
1029
- host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
1030
- else:
1031
- host = get_ip()
1032
- host = maybe_wrap_ipv6_address(host)
1033
-
1034
- bootstrap_server_url = f"{host}:{self.bootstrap_port}"
1035
- url = f"http://{bootstrap_server_url}/route"
1036
- payload = {
1037
- "role": "Prefill",
1038
- "attn_tp_size": self.attn_tp_size,
1039
- "attn_tp_rank": self.attn_tp_rank,
1040
- "attn_dp_size": self.attn_dp_size,
1041
- "attn_dp_rank": self.attn_dp_rank,
1042
- "pp_size": self.pp_size,
1043
- "pp_rank": self.pp_rank,
1044
- "system_dp_size": self.system_dp_size,
1045
- "system_dp_rank": self.system_dp_rank,
1046
- "rank_ip": self.local_ip,
1047
- "rank_port": self.rank_port,
1048
- }
1049
-
1050
- try:
1051
- response = requests.put(url, json=payload, timeout=5)
1052
- if response.status_code == 200:
1053
- logger.debug("Prefill successfully registered to bootstrap server.")
1054
- else:
1055
- logger.error(
1056
- f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
1057
- )
1058
- except Exception as e:
1059
- logger.error(
1060
- f"Prefill instance failed to register to bootstrap server: {e}"
1061
- )
1062
-
1063
937
  def _handle_node_failure(self, failed_bootstrap_addr):
1064
938
  with self.connection_lock:
1065
939
  keys_to_remove = [
@@ -1098,7 +972,7 @@ class MooncakeKVManager(BaseKVManager):
1098
972
  )
1099
973
 
1100
974
 
1101
- class MooncakeKVSender(BaseKVSender):
975
+ class MooncakeKVSender(CommonKVSender):
1102
976
 
1103
977
  def __init__(
1104
978
  self,
@@ -1108,19 +982,9 @@ class MooncakeKVSender(BaseKVSender):
1108
982
  dest_tp_ranks: List[int],
1109
983
  pp_rank: int,
1110
984
  ):
1111
- self.kv_mgr = mgr
1112
- self.bootstrap_room = bootstrap_room
1113
- self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
1114
- self.aux_index = None
1115
- self.bootstrap_server_url = bootstrap_addr
985
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
1116
986
  self.conclude_state = None
1117
987
  self.init_time = time.time()
1118
- # inner state
1119
- self.curr_idx = 0
1120
-
1121
- def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
1122
- self.num_kv_indices = num_kv_indices
1123
- self.aux_index = aux_index
1124
988
 
1125
989
  def send(
1126
990
  self,
@@ -1198,7 +1062,7 @@ class MooncakeKVSender(BaseKVSender):
1198
1062
  self.conclude_state = KVPoll.Failed
1199
1063
 
1200
1064
 
1201
- class MooncakeKVReceiver(BaseKVReceiver):
1065
+ class MooncakeKVReceiver(CommonKVReceiver):
1202
1066
  _ctx = zmq.Context()
1203
1067
  _socket_cache = {}
1204
1068
  _socket_locks = {}
@@ -1209,166 +1073,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
1209
1073
  mgr: MooncakeKVManager,
1210
1074
  bootstrap_addr: str,
1211
1075
  bootstrap_room: Optional[int] = None,
1212
- data_parallel_rank: Optional[int] = None,
1076
+ prefill_dp_rank: Optional[int] = None,
1213
1077
  ):
1214
- self.bootstrap_room = bootstrap_room
1215
- self.bootstrap_addr = bootstrap_addr
1216
- self.kv_mgr = mgr
1217
- self.session_id = self.kv_mgr.get_session_id()
1218
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
1078
+ self.session_id = mgr.get_session_id()
1219
1079
  self.conclude_state = None
1220
1080
  self.init_time = None
1221
- self.data_parallel_rank = data_parallel_rank
1081
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
1222
1082
 
1223
- if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
1224
- (
1225
- self.prefill_attn_tp_size,
1226
- self.prefill_dp_size,
1227
- self.prefill_pp_size,
1228
- ) = self._get_prefill_parallel_info_from_server()
1229
- if (
1230
- self.prefill_attn_tp_size is None
1231
- or self.prefill_dp_size is None
1232
- or self.prefill_pp_size is None
1233
- ):
1234
- self.kv_mgr.record_failure(
1235
- self.bootstrap_room,
1236
- f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
1237
- )
1238
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1239
- return
1240
- else:
1241
- logger.debug(
1242
- f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
1243
- )
1244
- self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
1245
- self.prefill_attn_tp_size
1246
- )
1247
- self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
1248
- self.prefill_dp_size
1249
- )
1250
- self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
1251
- self.prefill_pp_size
1252
- )
1253
- else:
1254
- self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
1255
- self.bootstrap_addr
1256
- ]
1257
- self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
1258
- self.bootstrap_addr
1259
- ]
1260
- self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
1261
- self.bootstrap_addr
1262
- ]
1263
-
1264
- # Currently, we don't allow prefill instance and decode instance to
1265
- # have different TP sizes per DP rank, except for models using MLA.
1266
- if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
1267
- self.target_tp_rank = (
1268
- self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1269
- )
1270
- self.required_dst_info_num = 1
1271
- self.required_prefill_response_num = 1 * (
1272
- self.prefill_pp_size // self.kv_mgr.pp_size
1273
- )
1274
- self.target_tp_ranks = [self.target_tp_rank]
1275
- elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
1276
- if not self.kv_mgr.is_mla_backend:
1277
- logger.warning_once(
1278
- "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1279
- )
1280
- self.target_tp_rank = (
1281
- self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1282
- ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
1283
- self.required_dst_info_num = (
1284
- self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
1285
- )
1286
- self.required_prefill_response_num = 1 * (
1287
- self.prefill_pp_size // self.kv_mgr.pp_size
1288
- )
1289
- self.target_tp_ranks = [self.target_tp_rank]
1290
- else:
1291
- if not self.kv_mgr.is_mla_backend:
1292
- logger.warning_once(
1293
- "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1294
- )
1295
- # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
1296
- self.target_tp_ranks = [
1297
- rank
1298
- for rank in range(
1299
- (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
1300
- * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1301
- (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
1302
- * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1303
- )
1304
- ]
1305
-
1306
- # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
1307
- # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
1308
- # or the KVPoll will never be set correctly
1309
- self.target_tp_rank = self.target_tp_ranks[0]
1310
- self.required_dst_info_num = 1
1311
- if self.kv_mgr.is_mla_backend:
1312
- self.required_prefill_response_num = (
1313
- self.prefill_pp_size // self.kv_mgr.pp_size
1314
- )
1315
- else:
1316
- self.required_prefill_response_num = (
1317
- self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1318
- ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
1319
-
1320
- if self.data_parallel_rank is not None:
1321
- logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
1322
- self.target_dp_group = self.data_parallel_rank
1323
- else:
1324
- self.target_dp_group = bootstrap_room % self.prefill_dp_size
1325
-
1326
- self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
1327
- self.required_prefill_response_num
1328
- )
1329
- # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
1330
- bootstrap_key = (
1331
- f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
1332
- )
1333
-
1334
- if bootstrap_key not in self.kv_mgr.connection_pool:
1335
- bootstrap_infos = []
1336
- for target_tp_rank in self.target_tp_ranks:
1337
- for target_pp_rank in range(self.prefill_pp_size):
1338
- bootstrap_info = self._get_bootstrap_info_from_server(
1339
- target_tp_rank, self.target_dp_group, target_pp_rank
1340
- )
1341
- if bootstrap_info is not None:
1342
- if self.kv_mgr.is_mla_backend:
1343
- # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
1344
- bootstrap_info["is_dummy"] = not bool(
1345
- target_tp_rank == self.target_tp_rank
1346
- or self.target_tp_rank is None
1347
- )
1348
- else:
1349
- # For non-MLA: all target_tp_ranks are selected real ranks
1350
- bootstrap_info["is_dummy"] = False
1351
- logger.debug(
1352
- f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
1353
- )
1354
- bootstrap_infos.append(bootstrap_info)
1355
- else:
1356
- self.kv_mgr.record_failure(
1357
- self.bootstrap_room,
1358
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
1359
- )
1360
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1361
- return
1362
-
1363
- self.bootstrap_infos = bootstrap_infos
1364
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
1365
-
1366
- # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
1367
- self._register_kv_args()
1368
- else:
1369
- self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
1370
-
1371
- assert len(self.bootstrap_infos) > 0
1372
1083
  self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
1373
1084
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
1374
1085
 
@@ -1391,29 +1102,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1391
1102
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
1392
1103
  return None
1393
1104
 
1394
- def _get_prefill_parallel_info_from_server(
1395
- self,
1396
- ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
1397
- """Fetch the prefill parallel info from the bootstrap server."""
1398
- try:
1399
- url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
1400
- response = requests.get(url)
1401
- if response.status_code == 200:
1402
- prefill_parallel_info = response.json()
1403
- return (
1404
- int(prefill_parallel_info["prefill_attn_tp_size"]),
1405
- int(prefill_parallel_info["prefill_dp_size"]),
1406
- int(prefill_parallel_info["prefill_pp_size"]),
1407
- )
1408
- else:
1409
- logger.error(
1410
- f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
1411
- )
1412
- return None, None, None
1413
- except Exception as e:
1414
- logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
1415
- return None, None, None
1416
-
1417
1105
  def _register_kv_args(self):
1418
1106
  for bootstrap_info in self.bootstrap_infos:
1419
1107
  packed_kv_data_ptrs = b"".join(
@@ -1445,28 +1133,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1445
1133
  ]
1446
1134
  )
1447
1135
 
1448
- @classmethod
1449
- def _connect(cls, endpoint: str, is_ipv6: bool = False):
1450
- with cls._global_lock:
1451
- if endpoint not in cls._socket_cache:
1452
- sock = cls._ctx.socket(zmq.PUSH)
1453
- if is_ipv6:
1454
- sock.setsockopt(zmq.IPV6, 1)
1455
- sock.connect(endpoint)
1456
- cls._socket_cache[endpoint] = sock
1457
- cls._socket_locks[endpoint] = threading.Lock()
1458
- return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
1459
-
1460
- @classmethod
1461
- def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
1462
- ip_address = bootstrap_info["rank_ip"]
1463
- port = bootstrap_info["rank_port"]
1464
- is_ipv6_address = is_valid_ipv6_address(ip_address)
1465
- sock, lock = cls._connect(
1466
- format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
1467
- )
1468
- return sock, lock
1469
-
1470
1136
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
1471
1137
  for bootstrap_info in self.bootstrap_infos:
1472
1138
  sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
@@ -1544,153 +1210,5 @@ class MooncakeKVReceiver(BaseKVReceiver):
1544
1210
  self.conclude_state = KVPoll.Failed
1545
1211
 
1546
1212
 
1547
- class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1548
- def __init__(self, port: int):
1549
- self.port = port
1550
- self.app = web.Application()
1551
- self.store = dict()
1552
- self.lock = asyncio.Lock()
1553
- self._setup_routes()
1554
- self.pp_size = None
1555
- self.attn_tp_size = None
1556
- self.dp_size = None
1557
- self.prefill_port_table: Dict[
1558
- int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
1559
- ] = {}
1560
-
1561
- # Start bootstrap server
1562
- self.thread = threading.Thread(target=self._run_server, daemon=True)
1563
- self.run()
1564
-
1565
- def run(self):
1566
- self.thread.start()
1567
-
1568
- def _setup_routes(self):
1569
- self.app.router.add_route("*", "/route", self._handle_route)
1570
- self.app.router.add_get("/health", self._handle_health_check)
1571
-
1572
- async def _handle_health_check(self, request):
1573
- return web.Response(text="OK", status=200)
1574
-
1575
- async def _handle_route(self, request: web.Request):
1576
- method = request.method
1577
- if method == "PUT":
1578
- return await self._handle_route_put(request)
1579
- elif method == "GET":
1580
- return await self._handle_route_get(request)
1581
- else:
1582
- return web.Response(
1583
- text="Method not allowed", status=405, content_type="application/json"
1584
- )
1585
-
1586
- async def _handle_route_put(self, request: web.Request):
1587
- data = await request.json()
1588
- role = data["role"]
1589
- attn_tp_size = data["attn_tp_size"]
1590
- attn_tp_rank = data["attn_tp_rank"]
1591
- attn_dp_size = data["attn_dp_size"]
1592
- attn_dp_rank = data["attn_dp_rank"]
1593
- pp_size = data["pp_size"]
1594
- pp_rank = data["pp_rank"]
1595
- system_dp_size = data["system_dp_size"]
1596
- system_dp_rank = data["system_dp_rank"]
1597
- rank_ip = data["rank_ip"]
1598
- rank_port = int(data["rank_port"])
1599
-
1600
- if self.attn_tp_size is None:
1601
- self.attn_tp_size = attn_tp_size
1602
-
1603
- if self.dp_size is None:
1604
- self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
1605
-
1606
- if self.pp_size is None:
1607
- self.pp_size = pp_size
1608
-
1609
- if role == "Prefill":
1610
- if system_dp_size == 1:
1611
- dp_group = attn_dp_rank
1612
- else:
1613
- dp_group = system_dp_rank
1614
-
1615
- # Add lock to make sure thread-safe
1616
- async with self.lock:
1617
- if dp_group not in self.prefill_port_table:
1618
- self.prefill_port_table[dp_group] = {}
1619
- if attn_tp_rank not in self.prefill_port_table[dp_group]:
1620
- self.prefill_port_table[dp_group][attn_tp_rank] = {}
1621
-
1622
- self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
1623
- "rank_ip": rank_ip,
1624
- "rank_port": rank_port,
1625
- }
1626
- logger.debug(
1627
- f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1628
- )
1629
-
1630
- return web.Response(text="OK", status=200)
1631
-
1632
- async def _handle_route_get(self, request: web.Request):
1633
- engine_rank = request.query.get("engine_rank")
1634
- target_dp_group = request.query.get("target_dp_group")
1635
- target_pp_rank = request.query.get("target_pp_rank")
1636
- if not engine_rank or not target_dp_group or not target_pp_rank:
1637
- return web.Response(text="Missing inputs for bootstrap server.", status=400)
1638
-
1639
- # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
1640
- if (
1641
- int(engine_rank) == -1
1642
- and int(target_dp_group) == -1
1643
- and int(target_pp_rank) == -1
1644
- ):
1645
- prefill_parallel_info = {
1646
- "prefill_attn_tp_size": self.attn_tp_size,
1647
- "prefill_dp_size": self.dp_size,
1648
- "prefill_pp_size": self.pp_size,
1649
- }
1650
- return web.json_response(prefill_parallel_info, status=200)
1651
-
1652
- # Find corresponding prefill info
1653
- async with self.lock:
1654
- bootstrap_info = self.prefill_port_table[int(target_dp_group)][
1655
- int(engine_rank)
1656
- ][int(target_pp_rank)]
1657
-
1658
- if bootstrap_info is not None:
1659
- return web.json_response(bootstrap_info, status=200)
1660
- else:
1661
- return web.Response(text="Bootstrap info not Found", status=404)
1662
-
1663
- def _run_server(self):
1664
- try:
1665
- # Event Loop
1666
- self._loop = asyncio.new_event_loop()
1667
- asyncio.set_event_loop(self._loop)
1668
-
1669
- access_log = None
1670
- if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
1671
- access_log = self.app.logger
1672
-
1673
- self._runner = web.AppRunner(self.app, access_log=access_log)
1674
- self._loop.run_until_complete(self._runner.setup())
1675
-
1676
- site = web.TCPSite(self._runner, port=self.port)
1677
- self._loop.run_until_complete(site.start())
1678
- self._loop.run_forever()
1679
- except Exception as e:
1680
- logger.error(f"Server error: {str(e)}")
1681
- finally:
1682
- # Cleanup
1683
- self._loop.run_until_complete(self._runner.cleanup())
1684
- self._loop.close()
1685
-
1686
- def close(self):
1687
- """Shutdown"""
1688
- if self._loop is not None and self._loop.is_running():
1689
- self._loop.call_soon_threadsafe(self._loop.stop)
1690
- logger.info("Stopping server loop...")
1691
-
1692
- if self.thread.is_alive():
1693
- self.thread.join(timeout=2)
1694
- logger.info("Server thread stopped")
1695
-
1696
- def poll(self) -> KVPoll: ...
1213
+ class MooncakeKVBootstrapServer(CommonKVBootstrapServer):
1214
+ pass