sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import concurrent.futures
5
4
  import ctypes
6
5
  import dataclasses
7
6
  import logging
8
7
  import os
9
- import queue
10
- import socket
11
8
  import struct
12
9
  import threading
13
10
  import time
14
11
  from collections import defaultdict
15
- from functools import cache
16
- from typing import Dict, List, Optional, Tuple, Union
12
+ from typing import Dict, List, Optional, Set, Tuple
17
13
 
18
14
  import numpy as np
19
15
  import numpy.typing as npt
20
16
  import requests
21
17
  import zmq
22
- from aiohttp import web
23
-
24
- from sglang.srt.disaggregation.base.conn import (
25
- BaseKVBootstrapServer,
26
- BaseKVManager,
27
- BaseKVReceiver,
28
- BaseKVSender,
29
- KVArgs,
30
- KVPoll,
18
+
19
+ from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
20
+ from sglang.srt.disaggregation.common.conn import (
21
+ CommonKVBootstrapServer,
22
+ CommonKVManager,
23
+ CommonKVReceiver,
24
+ CommonKVSender,
31
25
  )
32
26
  from sglang.srt.disaggregation.common.utils import (
33
27
  FastQueue,
@@ -35,23 +29,12 @@ from sglang.srt.disaggregation.common.utils import (
35
29
  )
36
30
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
37
31
  from sglang.srt.disaggregation.utils import DisaggregationMode
38
- from sglang.srt.distributed import get_pp_group
39
- from sglang.srt.layers.dp_attention import (
40
- get_attention_dp_rank,
41
- get_attention_dp_size,
42
- get_attention_tp_rank,
43
- get_attention_tp_size,
44
- )
45
32
  from sglang.srt.server_args import ServerArgs
46
33
  from sglang.srt.utils import (
47
34
  format_tcp_address,
48
35
  get_bool_env_var,
49
- get_free_port,
50
36
  get_int_env_var,
51
- get_ip,
52
- get_local_ip_auto,
53
37
  is_valid_ipv6_address,
54
- maybe_wrap_ipv6_address,
55
38
  )
56
39
 
57
40
  logger = logging.getLogger(__name__)
@@ -75,6 +58,7 @@ class TransferKVChunk:
75
58
  index_slice: slice
76
59
  is_last: bool
77
60
  prefill_aux_index: Optional[int]
61
+ state_indices: Optional[List[int]]
78
62
 
79
63
 
80
64
  # decode
@@ -86,6 +70,7 @@ class TransferInfo:
86
70
  mooncake_session_id: str
87
71
  dst_kv_indices: npt.NDArray[np.int32]
88
72
  dst_aux_index: int
73
+ dst_state_indices: List[int]
89
74
  required_dst_info_num: int
90
75
  is_dummy: bool
91
76
 
@@ -95,9 +80,14 @@ class TransferInfo:
95
80
  is_dummy = True
96
81
  dst_kv_indices = np.array([], dtype=np.int32)
97
82
  dst_aux_index = None
83
+ dst_state_indices = []
98
84
  else:
99
85
  dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
100
86
  dst_aux_index = int(msg[5].decode("ascii"))
87
+ if msg[6] == b"":
88
+ dst_state_indices = []
89
+ else:
90
+ dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
101
91
  is_dummy = False
102
92
  return cls(
103
93
  room=int(msg[0].decode("ascii")),
@@ -106,7 +96,8 @@ class TransferInfo:
106
96
  mooncake_session_id=msg[3].decode("ascii"),
107
97
  dst_kv_indices=dst_kv_indices,
108
98
  dst_aux_index=dst_aux_index,
109
- required_dst_info_num=int(msg[6].decode("ascii")),
99
+ dst_state_indices=dst_state_indices,
100
+ required_dst_info_num=int(msg[7].decode("ascii")),
110
101
  is_dummy=is_dummy,
111
102
  )
112
103
 
@@ -120,6 +111,7 @@ class KVArgsRegisterInfo:
120
111
  mooncake_session_id: str
121
112
  dst_kv_ptrs: list[int]
122
113
  dst_aux_ptrs: list[int]
114
+ dst_state_data_ptrs: list[int]
123
115
  dst_tp_rank: int
124
116
  dst_attn_tp_size: int
125
117
  dst_kv_item_len: int
@@ -133,9 +125,10 @@ class KVArgsRegisterInfo:
133
125
  mooncake_session_id=msg[3].decode("ascii"),
134
126
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
135
127
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
136
- dst_tp_rank=int(msg[6].decode("ascii")),
137
- dst_attn_tp_size=int(msg[7].decode("ascii")),
138
- dst_kv_item_len=int(msg[8].decode("ascii")),
128
+ dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
129
+ dst_tp_rank=int(msg[7].decode("ascii")),
130
+ dst_attn_tp_size=int(msg[8].decode("ascii")),
131
+ dst_kv_item_len=int(msg[9].decode("ascii")),
139
132
  )
140
133
 
141
134
 
@@ -159,7 +152,7 @@ class AuxDataCodec:
159
152
  return
160
153
 
161
154
 
162
- class MooncakeKVManager(BaseKVManager):
155
+ class MooncakeKVManager(CommonKVManager):
163
156
  AUX_DATA_HEADER = b"AUX_DATA"
164
157
 
165
158
  def __init__(
@@ -169,48 +162,19 @@ class MooncakeKVManager(BaseKVManager):
169
162
  server_args: ServerArgs,
170
163
  is_mla_backend: Optional[bool] = False,
171
164
  ):
172
- self.kv_args = args
173
- self.local_ip = get_local_ip_auto()
174
- self.is_mla_backend = is_mla_backend
175
- self.disaggregation_mode = disaggregation_mode
165
+ super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
176
166
  self.init_engine()
177
- # for p/d multi node infer
178
- self.bootstrap_host = server_args.host
179
- self.bootstrap_port = server_args.disaggregation_bootstrap_port
180
- self.dist_init_addr = server_args.dist_init_addr
181
- self.attn_tp_size = get_attention_tp_size()
182
- self.attn_tp_rank = get_attention_tp_rank()
183
- self.attn_dp_size = get_attention_dp_size()
184
- self.attn_dp_rank = get_attention_dp_rank()
185
- self.system_dp_size = (
186
- 1 if server_args.enable_dp_attention else server_args.dp_size
187
- )
188
- self.system_dp_rank = (
189
- self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
190
- )
191
- self.pp_size = server_args.pp_size
192
- self.pp_rank = self.kv_args.pp_rank
193
- self.request_status: Dict[int, KVPoll] = {}
194
- self.rank_port = None
195
- self.server_socket = zmq.Context().socket(zmq.PULL)
196
- if is_valid_ipv6_address(self.local_ip):
197
- self.server_socket.setsockopt(zmq.IPV6, 1)
198
-
199
167
  self.register_buffer_to_engine()
200
168
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
201
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
202
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
203
169
  self.start_prefill_thread()
204
- self._register_to_bootstrap()
205
170
  self.session_failures = defaultdict(int)
206
171
  self.failed_sessions = set()
207
172
  self.session_lock = threading.Lock()
208
- self.pp_group = get_pp_group()
209
173
  # Determine the number of threads to use for kv sender
210
174
  cpu_count = os.cpu_count()
211
175
  transfer_thread_pool_size = get_int_env_var(
212
176
  "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
213
- min(max(4, int(0.75 * cpu_count) // 8), 12),
177
+ min(max(4, int(0.5 * cpu_count) // 8), 12),
214
178
  )
215
179
  transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
216
180
  self.transfer_queues: List[FastQueue] = [
@@ -245,8 +209,6 @@ class MooncakeKVManager(BaseKVManager):
245
209
  self.session_pool = defaultdict(requests.Session)
246
210
  self.session_pool_lock = threading.Lock()
247
211
  self.addr_to_rooms_tracker = defaultdict(set)
248
- self.connection_lock = threading.Lock()
249
- self.required_prefill_response_num_table: Dict[int, int] = {}
250
212
  self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
251
213
  # Heartbeat interval should be at least 2 seconds
252
214
  self.heartbeat_interval = max(
@@ -257,20 +219,12 @@ class MooncakeKVManager(BaseKVManager):
257
219
  get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
258
220
  )
259
221
  self.start_decode_thread()
260
- self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
261
- self.prefill_attn_tp_size_table: Dict[str, int] = {}
262
- self.prefill_dp_size_table: Dict[str, int] = {}
263
- self.prefill_pp_size_table: Dict[str, int] = {}
264
222
  # If a timeout happens on the decode side, it means decode instances
265
223
  # fail to receive the KV Cache transfer done signal after bootstrapping.
266
224
  # These timeout requests should be aborted to release the tree cache.
267
225
  self.waiting_timeout = get_int_env_var(
268
226
  "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
269
227
  )
270
- else:
271
- raise ValueError(
272
- f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
273
- )
274
228
 
275
229
  self.failure_records: Dict[int, str] = {}
276
230
  self.failure_lock = threading.Lock()
@@ -295,13 +249,11 @@ class MooncakeKVManager(BaseKVManager):
295
249
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
296
250
  )
297
251
 
298
- @cache
299
- def _connect(self, endpoint: str, is_ipv6: bool = False):
300
- socket = zmq.Context().socket(zmq.PUSH)
301
- if is_ipv6:
302
- socket.setsockopt(zmq.IPV6, 1)
303
- socket.connect(endpoint)
304
- return socket
252
+ # Batch register state/extra pool data buffers
253
+ if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
254
+ self.engine.batch_register(
255
+ self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
256
+ )
305
257
 
306
258
  def _transfer_data(self, mooncake_session_id, transfer_blocks):
307
259
  if not transfer_blocks:
@@ -312,62 +264,60 @@ class MooncakeKVManager(BaseKVManager):
312
264
  mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
313
265
  )
314
266
 
315
- def send_kvcache(
267
+ def _send_kvcache_generic(
316
268
  self,
317
269
  mooncake_session_id: str,
318
- prefill_kv_indices: npt.NDArray[np.int32],
319
- dst_kv_ptrs: list[int],
320
- dst_kv_indices: npt.NDArray[np.int32],
270
+ src_data_ptrs: list[int],
271
+ dst_data_ptrs: list[int],
272
+ item_lens: list[int],
273
+ prefill_data_indices: npt.NDArray[np.int32],
274
+ dst_data_indices: npt.NDArray[np.int32],
321
275
  executor: concurrent.futures.ThreadPoolExecutor,
322
- ):
323
- # Group by indices
276
+ ) -> int:
277
+ """
278
+ Generic KV cache transfer supporting both MHA and MLA architectures.
279
+ This method is used by both send_kvcache (full pool) and maybe_send_extra.
280
+ """
281
+ # Group by indices for optimization
324
282
  prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
325
- prefill_kv_indices, dst_kv_indices
283
+ prefill_data_indices, dst_data_indices
326
284
  )
327
285
 
328
286
  layers_params = None
329
287
 
330
288
  # pp is not supported on the decode side yet
331
- start_layer = self.kv_args.prefill_start_layer
332
- end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
333
289
  if self.is_mla_backend:
334
- src_kv_ptrs = self.kv_args.kv_data_ptrs
335
- layers_per_pp_stage = len(src_kv_ptrs)
336
- dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
337
- kv_item_len = self.kv_args.kv_item_lens[0]
290
+ src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
291
+ self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
292
+ )
293
+ kv_item_len = item_lens[0]
338
294
  layers_params = [
339
295
  (
340
296
  src_kv_ptrs[layer_id],
341
297
  dst_kv_ptrs[layer_id],
342
298
  kv_item_len,
343
299
  )
344
- for layer_id in range(layers_per_pp_stage)
300
+ for layer_id in range(layers_current_pp_stage)
345
301
  ]
346
302
  else:
347
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
348
- dst_num_total_layers = num_kv_layers * self.pp_size
349
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
350
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
351
- layers_per_pp_stage = len(src_k_ptrs)
352
- dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
353
- dst_v_ptrs = dst_kv_ptrs[
354
- dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
355
- ]
356
- kv_item_len = self.kv_args.kv_item_lens[0]
303
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
304
+ self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
305
+ )
306
+ kv_item_len = item_lens[0]
357
307
  layers_params = [
358
308
  (
359
309
  src_k_ptrs[layer_id],
360
310
  dst_k_ptrs[layer_id],
361
311
  kv_item_len,
362
312
  )
363
- for layer_id in range(layers_per_pp_stage)
313
+ for layer_id in range(layers_current_pp_stage)
364
314
  ] + [
365
315
  (
366
316
  src_v_ptrs[layer_id],
367
317
  dst_v_ptrs[layer_id],
368
318
  kv_item_len,
369
319
  )
370
- for layer_id in range(layers_per_pp_stage)
320
+ for layer_id in range(layers_current_pp_stage)
371
321
  ]
372
322
  assert layers_params is not None
373
323
 
@@ -417,6 +367,24 @@ class MooncakeKVManager(BaseKVManager):
417
367
 
418
368
  return 0
419
369
 
370
+ def send_kvcache(
371
+ self,
372
+ mooncake_session_id: str,
373
+ prefill_kv_indices: npt.NDArray[np.int32],
374
+ dst_kv_ptrs: list[int],
375
+ dst_kv_indices: npt.NDArray[np.int32],
376
+ executor: concurrent.futures.ThreadPoolExecutor,
377
+ ):
378
+ return self._send_kvcache_generic(
379
+ mooncake_session_id=mooncake_session_id,
380
+ src_data_ptrs=self.kv_args.kv_data_ptrs,
381
+ dst_data_ptrs=dst_kv_ptrs,
382
+ item_lens=self.kv_args.kv_item_lens,
383
+ prefill_data_indices=prefill_kv_indices,
384
+ dst_data_indices=dst_kv_indices,
385
+ executor=executor,
386
+ )
387
+
420
388
  def send_kvcache_slice(
421
389
  self,
422
390
  mooncake_session_id: str,
@@ -465,18 +433,9 @@ class MooncakeKVManager(BaseKVManager):
465
433
  num_heads_to_send = dst_heads_per_rank
466
434
  dst_head_start_offset = 0
467
435
 
468
- # pp is not supported on the decode side yet
469
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
470
- dst_num_total_layers = num_kv_layers * self.pp_size
471
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
472
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
473
- layers_per_pp_stage = len(src_k_ptrs)
474
- start_layer = self.pp_rank * layers_per_pp_stage
475
- end_layer = start_layer + layers_per_pp_stage
476
- dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
477
- dst_v_ptrs = dst_kv_ptrs[
478
- dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
479
- ]
436
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
437
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
438
+ )
480
439
 
481
440
  # Calculate precise byte offset and length for the sub-slice within the token
482
441
  src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
@@ -502,7 +461,7 @@ class MooncakeKVManager(BaseKVManager):
502
461
  dst_head_slice_offset,
503
462
  heads_bytes_per_token_to_send,
504
463
  )
505
- for layer_id in range(layers_per_pp_stage)
464
+ for layer_id in range(layers_current_pp_stage)
506
465
  ] + [
507
466
  (
508
467
  src_v_ptrs[layer_id],
@@ -513,7 +472,7 @@ class MooncakeKVManager(BaseKVManager):
513
472
  dst_head_slice_offset,
514
473
  heads_bytes_per_token_to_send,
515
474
  )
516
- for layer_id in range(layers_per_pp_stage)
475
+ for layer_id in range(layers_current_pp_stage)
517
476
  ]
518
477
 
519
478
  def process_layer_tp_aware(layer_params):
@@ -654,6 +613,79 @@ class MooncakeKVManager(BaseKVManager):
654
613
  ]
655
614
  )
656
615
 
616
+ def _handle_aux_data(self, msg: List[bytes]):
617
+ """Handle AUX_DATA messages received by the decode thread."""
618
+ room = int(msg[1].decode("ascii"))
619
+ buffer_index = int(msg[2].decode("ascii"))
620
+ aux_index = int(msg[3].decode("ascii"))
621
+ data_length = struct.unpack(">I", msg[4])[0]
622
+ data = msg[5]
623
+
624
+ if len(data) != data_length:
625
+ logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
626
+ return
627
+
628
+ AuxDataCodec.deserialize_data_to_buffer(
629
+ self.kv_args, buffer_index, aux_index, data
630
+ )
631
+
632
+ logger.debug(
633
+ f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
634
+ )
635
+
636
+ def maybe_send_extra(
637
+ self,
638
+ req: TransferInfo,
639
+ prefill_state_indices: list[int],
640
+ dst_state_data_ptrs: list[int],
641
+ executor: concurrent.futures.ThreadPoolExecutor,
642
+ ):
643
+ """Send state or extra pool data with type-specific handling."""
644
+ state_type = getattr(self.kv_args, "state_type", "none")
645
+
646
+ if state_type == "mamba":
647
+ return self._send_mamba_state(
648
+ req,
649
+ prefill_state_indices,
650
+ dst_state_data_ptrs,
651
+ )
652
+ elif state_type in ["swa", "nsa"]:
653
+ # Reuse _send_kvcache_generic interface to send extra pool data
654
+ prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
655
+ dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
656
+ return self._send_kvcache_generic(
657
+ mooncake_session_id=req.mooncake_session_id,
658
+ src_data_ptrs=self.kv_args.state_data_ptrs,
659
+ dst_data_ptrs=dst_state_data_ptrs,
660
+ item_lens=self.kv_args.state_item_lens,
661
+ prefill_data_indices=prefill_state_indices,
662
+ dst_data_indices=dst_state_indices,
663
+ executor=executor,
664
+ )
665
+ else:
666
+ return 0
667
+
668
+ def _send_mamba_state(
669
+ self,
670
+ req: TransferInfo,
671
+ prefill_mamba_index: list[int],
672
+ dst_state_data_ptrs: list[int],
673
+ ):
674
+ """Transfer Mamba states."""
675
+ assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
676
+
677
+ transfer_blocks = []
678
+ prefill_state_data_ptrs = self.kv_args.state_data_ptrs
679
+ prefill_state_item_lens = self.kv_args.state_item_lens
680
+
681
+ for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
682
+ length = prefill_state_item_lens[i]
683
+ src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
684
+ dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
685
+ transfer_blocks.append((src_addr, dst_addr, length))
686
+
687
+ return self._transfer_data(req.mooncake_session_id, transfer_blocks)
688
+
657
689
  def sync_status_to_decode_endpoint(
658
690
  self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
659
691
  ):
@@ -763,6 +795,22 @@ class MooncakeKVManager(BaseKVManager):
763
795
  break
764
796
 
765
797
  if kv_chunk.is_last:
798
+ if kv_chunk.state_indices is not None:
799
+ if not self.is_mla_backend and (
800
+ self.attn_tp_size
801
+ != target_rank_registration_info.dst_attn_tp_size
802
+ ):
803
+ raise RuntimeError(
804
+ f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
805
+ )
806
+
807
+ self.maybe_send_extra(
808
+ req,
809
+ kv_chunk.state_indices,
810
+ target_rank_registration_info.dst_state_data_ptrs,
811
+ executor,
812
+ )
813
+
766
814
  if self.pp_group.is_last_rank:
767
815
  # Only the last chunk we need to send the aux data
768
816
  ret = self.send_aux(
@@ -802,11 +850,7 @@ class MooncakeKVManager(BaseKVManager):
802
850
  f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
803
851
  )
804
852
 
805
- def _bind_server_socket(self):
806
- self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
807
-
808
853
  def start_prefill_thread(self):
809
- self.rank_port = get_free_port()
810
854
  self._bind_server_socket()
811
855
 
812
856
  def bootstrap_thread():
@@ -830,7 +874,7 @@ class MooncakeKVManager(BaseKVManager):
830
874
  )
831
875
  continue
832
876
  else:
833
- required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
877
+ required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
834
878
  room = int(room)
835
879
  if room not in self.transfer_infos:
836
880
  self.transfer_infos[room] = {}
@@ -844,28 +888,7 @@ class MooncakeKVManager(BaseKVManager):
844
888
 
845
889
  threading.Thread(target=bootstrap_thread).start()
846
890
 
847
- def _handle_aux_data(self, msg: List[bytes]):
848
- """Handle AUX_DATA messages received by the decode thread."""
849
- room = int(msg[1].decode("ascii"))
850
- buffer_index = int(msg[2].decode("ascii"))
851
- aux_index = int(msg[3].decode("ascii"))
852
- data_length = struct.unpack(">I", msg[4])[0]
853
- data = msg[5]
854
-
855
- if len(data) != data_length:
856
- logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
857
- return
858
-
859
- AuxDataCodec.deserialize_data_to_buffer(
860
- self.kv_args, buffer_index, aux_index, data
861
- )
862
-
863
- logger.debug(
864
- f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
865
- )
866
-
867
891
  def start_decode_thread(self):
868
- self.rank_port = get_free_port()
869
892
  self._bind_server_socket()
870
893
 
871
894
  def decode_thread():
@@ -962,6 +985,7 @@ class MooncakeKVManager(BaseKVManager):
962
985
  index_slice: slice,
963
986
  is_last: bool,
964
987
  aux_index: Optional[int] = None,
988
+ state_indices: Optional[List[int]] = None,
965
989
  ):
966
990
  assert self.disaggregation_mode == DisaggregationMode.PREFILL
967
991
  assert not is_last or (is_last and aux_index is not None)
@@ -995,6 +1019,7 @@ class MooncakeKVManager(BaseKVManager):
995
1019
  index_slice=index_slice,
996
1020
  is_last=is_last,
997
1021
  prefill_aux_index=aux_index,
1022
+ state_indices=state_indices,
998
1023
  )
999
1024
  )
1000
1025
 
@@ -1020,51 +1045,6 @@ class MooncakeKVManager(BaseKVManager):
1020
1045
  def get_session_id(self):
1021
1046
  return self.engine.get_session_id()
1022
1047
 
1023
- def _register_to_bootstrap(self):
1024
- """Register KVSender to bootstrap server via HTTP POST."""
1025
- if self.dist_init_addr:
1026
- # multi node case: bootstrap server's host is dist_init_addr
1027
- if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
1028
- if self.dist_init_addr.endswith("]"):
1029
- host = self.dist_init_addr
1030
- else:
1031
- host, _ = self.dist_init_addr.rsplit(":", 1)
1032
- else:
1033
- host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
1034
- else:
1035
- # single node case: bootstrap server's host is same as http server's host
1036
- host = self.bootstrap_host
1037
- host = maybe_wrap_ipv6_address(host)
1038
-
1039
- bootstrap_server_url = f"{host}:{self.bootstrap_port}"
1040
- url = f"http://{bootstrap_server_url}/route"
1041
- payload = {
1042
- "role": "Prefill",
1043
- "attn_tp_size": self.attn_tp_size,
1044
- "attn_tp_rank": self.attn_tp_rank,
1045
- "attn_dp_size": self.attn_dp_size,
1046
- "attn_dp_rank": self.attn_dp_rank,
1047
- "pp_size": self.pp_size,
1048
- "pp_rank": self.pp_rank,
1049
- "system_dp_size": self.system_dp_size,
1050
- "system_dp_rank": self.system_dp_rank,
1051
- "rank_ip": self.local_ip,
1052
- "rank_port": self.rank_port,
1053
- }
1054
-
1055
- try:
1056
- response = requests.put(url, json=payload, timeout=5)
1057
- if response.status_code == 200:
1058
- logger.debug("Prefill successfully registered to bootstrap server.")
1059
- else:
1060
- logger.error(
1061
- f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
1062
- )
1063
- except Exception as e:
1064
- logger.error(
1065
- f"Prefill instance failed to register to bootstrap server: {e}"
1066
- )
1067
-
1068
1048
  def _handle_node_failure(self, failed_bootstrap_addr):
1069
1049
  with self.connection_lock:
1070
1050
  keys_to_remove = [
@@ -1103,7 +1083,7 @@ class MooncakeKVManager(BaseKVManager):
1103
1083
  )
1104
1084
 
1105
1085
 
1106
- class MooncakeKVSender(BaseKVSender):
1086
+ class MooncakeKVSender(CommonKVSender):
1107
1087
 
1108
1088
  def __init__(
1109
1089
  self,
@@ -1113,23 +1093,14 @@ class MooncakeKVSender(BaseKVSender):
1113
1093
  dest_tp_ranks: List[int],
1114
1094
  pp_rank: int,
1115
1095
  ):
1116
- self.kv_mgr = mgr
1117
- self.bootstrap_room = bootstrap_room
1118
- self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
1119
- self.aux_index = None
1120
- self.bootstrap_server_url = bootstrap_addr
1096
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
1121
1097
  self.conclude_state = None
1122
1098
  self.init_time = time.time()
1123
- # inner state
1124
- self.curr_idx = 0
1125
-
1126
- def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
1127
- self.num_kv_indices = num_kv_indices
1128
- self.aux_index = aux_index
1129
1099
 
1130
1100
  def send(
1131
1101
  self,
1132
1102
  kv_indices: npt.NDArray[np.int32],
1103
+ state_indices: Optional[List[int]] = None,
1133
1104
  ):
1134
1105
  index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
1135
1106
  self.curr_idx += len(kv_indices)
@@ -1149,6 +1120,7 @@ class MooncakeKVSender(BaseKVSender):
1149
1120
  index_slice,
1150
1121
  True,
1151
1122
  aux_index=self.aux_index,
1123
+ state_indices=state_indices,
1152
1124
  )
1153
1125
 
1154
1126
  def poll(self) -> KVPoll:
@@ -1203,7 +1175,7 @@ class MooncakeKVSender(BaseKVSender):
1203
1175
  self.conclude_state = KVPoll.Failed
1204
1176
 
1205
1177
 
1206
- class MooncakeKVReceiver(BaseKVReceiver):
1178
+ class MooncakeKVReceiver(CommonKVReceiver):
1207
1179
  _ctx = zmq.Context()
1208
1180
  _socket_cache = {}
1209
1181
  _socket_locks = {}
@@ -1216,166 +1188,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
1216
1188
  bootstrap_room: Optional[int] = None,
1217
1189
  prefill_dp_rank: Optional[int] = None,
1218
1190
  ):
1219
- self.bootstrap_room = bootstrap_room
1220
- self.bootstrap_addr = bootstrap_addr
1221
- self.kv_mgr = mgr
1222
- self.session_id = self.kv_mgr.get_session_id()
1223
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
1191
+ self.session_id = mgr.get_session_id()
1224
1192
  self.conclude_state = None
1225
1193
  self.init_time = None
1194
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
1226
1195
 
1227
- if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
1228
- (
1229
- self.prefill_attn_tp_size,
1230
- self.prefill_dp_size,
1231
- self.prefill_pp_size,
1232
- ) = self._get_prefill_parallel_info_from_server()
1233
- if (
1234
- self.prefill_attn_tp_size is None
1235
- or self.prefill_dp_size is None
1236
- or self.prefill_pp_size is None
1237
- ):
1238
- self.kv_mgr.record_failure(
1239
- self.bootstrap_room,
1240
- f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
1241
- )
1242
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1243
- return
1244
- else:
1245
- logger.debug(
1246
- f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
1247
- )
1248
- self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
1249
- self.prefill_attn_tp_size
1250
- )
1251
- self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
1252
- self.prefill_dp_size
1253
- )
1254
- self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
1255
- self.prefill_pp_size
1256
- )
1257
- else:
1258
- self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
1259
- self.bootstrap_addr
1260
- ]
1261
- self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
1262
- self.bootstrap_addr
1263
- ]
1264
- self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
1265
- self.bootstrap_addr
1266
- ]
1267
-
1268
- # Currently, we don't allow prefill instance and decode instance to
1269
- # have different TP sizes per DP rank, except for models using MLA.
1270
- if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
1271
- self.target_tp_rank = (
1272
- self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1273
- )
1274
- self.required_dst_info_num = 1
1275
- self.required_prefill_response_num = 1 * (
1276
- self.prefill_pp_size // self.kv_mgr.pp_size
1277
- )
1278
- self.target_tp_ranks = [self.target_tp_rank]
1279
- elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
1280
- if not self.kv_mgr.is_mla_backend:
1281
- logger.warning_once(
1282
- "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1283
- )
1284
- self.target_tp_rank = (
1285
- self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1286
- ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
1287
- self.required_dst_info_num = (
1288
- self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
1289
- )
1290
- self.required_prefill_response_num = 1 * (
1291
- self.prefill_pp_size // self.kv_mgr.pp_size
1292
- )
1293
- self.target_tp_ranks = [self.target_tp_rank]
1294
- else:
1295
- if not self.kv_mgr.is_mla_backend:
1296
- logger.warning_once(
1297
- "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1298
- )
1299
- # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
1300
- self.target_tp_ranks = [
1301
- rank
1302
- for rank in range(
1303
- (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
1304
- * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1305
- (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
1306
- * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1307
- )
1308
- ]
1309
-
1310
- # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
1311
- # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
1312
- # or the KVPoll will never be set correctly
1313
- self.target_tp_rank = self.target_tp_ranks[0]
1314
- self.required_dst_info_num = 1
1315
- if self.kv_mgr.is_mla_backend:
1316
- self.required_prefill_response_num = (
1317
- self.prefill_pp_size // self.kv_mgr.pp_size
1318
- )
1319
- else:
1320
- self.required_prefill_response_num = (
1321
- self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1322
- ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
1323
-
1324
- if prefill_dp_rank is not None:
1325
- logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
1326
- self.prefill_dp_rank = prefill_dp_rank
1327
- else:
1328
- self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
1329
-
1330
- # FIXME: alias here: target_dp_group -> prefill_dp_rank
1331
- self.target_dp_group = self.prefill_dp_rank
1332
-
1333
- self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
1334
- self.required_prefill_response_num
1335
- )
1336
- # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
1337
- bootstrap_key = (
1338
- f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
1339
- )
1340
-
1341
- if bootstrap_key not in self.kv_mgr.connection_pool:
1342
- bootstrap_infos = []
1343
- for target_tp_rank in self.target_tp_ranks:
1344
- for target_pp_rank in range(self.prefill_pp_size):
1345
- bootstrap_info = self._get_bootstrap_info_from_server(
1346
- target_tp_rank, self.target_dp_group, target_pp_rank
1347
- )
1348
- if bootstrap_info is not None:
1349
- if self.kv_mgr.is_mla_backend:
1350
- # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
1351
- bootstrap_info["is_dummy"] = not bool(
1352
- target_tp_rank == self.target_tp_rank
1353
- or self.target_tp_rank is None
1354
- )
1355
- else:
1356
- # For non-MLA: all target_tp_ranks are selected real ranks
1357
- bootstrap_info["is_dummy"] = False
1358
- logger.debug(
1359
- f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
1360
- )
1361
- bootstrap_infos.append(bootstrap_info)
1362
- else:
1363
- self.kv_mgr.record_failure(
1364
- self.bootstrap_room,
1365
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
1366
- )
1367
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1368
- return
1369
-
1370
- self.bootstrap_infos = bootstrap_infos
1371
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
1372
-
1373
- # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
1374
- self._register_kv_args()
1375
- else:
1376
- self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
1377
-
1378
- assert len(self.bootstrap_infos) > 0
1379
1196
  self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
1380
1197
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
1381
1198
 
@@ -1398,29 +1215,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1398
1215
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
1399
1216
  return None
1400
1217
 
1401
- def _get_prefill_parallel_info_from_server(
1402
- self,
1403
- ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
1404
- """Fetch the prefill parallel info from the bootstrap server."""
1405
- try:
1406
- url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
1407
- response = requests.get(url)
1408
- if response.status_code == 200:
1409
- prefill_parallel_info = response.json()
1410
- return (
1411
- int(prefill_parallel_info["prefill_attn_tp_size"]),
1412
- int(prefill_parallel_info["prefill_dp_size"]),
1413
- int(prefill_parallel_info["prefill_pp_size"]),
1414
- )
1415
- else:
1416
- logger.error(
1417
- f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
1418
- )
1419
- return None, None, None
1420
- except Exception as e:
1421
- logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
1422
- return None, None, None
1423
-
1424
1218
  def _register_kv_args(self):
1425
1219
  for bootstrap_info in self.bootstrap_infos:
1426
1220
  packed_kv_data_ptrs = b"".join(
@@ -1429,6 +1223,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
1429
1223
  packed_aux_data_ptrs = b"".join(
1430
1224
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
1431
1225
  )
1226
+ packed_state_data_ptrs = b"".join(
1227
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
1228
+ )
1432
1229
  # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
1433
1230
  tp_rank = self.kv_mgr.kv_args.engine_rank
1434
1231
  kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
@@ -1446,35 +1243,27 @@ class MooncakeKVReceiver(BaseKVReceiver):
1446
1243
  self.session_id.encode("ascii"),
1447
1244
  packed_kv_data_ptrs,
1448
1245
  packed_aux_data_ptrs,
1246
+ packed_state_data_ptrs,
1449
1247
  dst_tp_rank,
1450
1248
  dst_attn_tp_size,
1451
1249
  dst_kv_item_len,
1452
1250
  ]
1453
1251
  )
1454
1252
 
1455
- @classmethod
1456
- def _connect(cls, endpoint: str, is_ipv6: bool = False):
1457
- with cls._global_lock:
1458
- if endpoint not in cls._socket_cache:
1459
- sock = cls._ctx.socket(zmq.PUSH)
1460
- if is_ipv6:
1461
- sock.setsockopt(zmq.IPV6, 1)
1462
- sock.connect(endpoint)
1463
- cls._socket_cache[endpoint] = sock
1464
- cls._socket_locks[endpoint] = threading.Lock()
1465
- return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
1466
-
1467
- @classmethod
1468
- def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
1469
- ip_address = bootstrap_info["rank_ip"]
1470
- port = bootstrap_info["rank_port"]
1471
- is_ipv6_address = is_valid_ipv6_address(ip_address)
1472
- sock, lock = cls._connect(
1473
- format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
1474
- )
1475
- return sock, lock
1253
+ def init(
1254
+ self,
1255
+ kv_indices: npt.NDArray[np.int32],
1256
+ aux_index: Optional[int] = None,
1257
+ state_indices: Optional[List[int]] = None,
1258
+ ):
1259
+ if self.bootstrap_infos is None:
1260
+ self.kv_mgr.record_failure(
1261
+ self.bootstrap_room,
1262
+ f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
1263
+ )
1264
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1265
+ return
1476
1266
 
1477
- def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
1478
1267
  for bootstrap_info in self.bootstrap_infos:
1479
1268
  sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
1480
1269
  is_dummy = bootstrap_info["is_dummy"]
@@ -1488,6 +1277,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
1488
1277
  self.session_id.encode("ascii"),
1489
1278
  kv_indices.tobytes() if not is_dummy else b"",
1490
1279
  str(aux_index).encode("ascii") if not is_dummy else b"",
1280
+ (
1281
+ np.array(
1282
+ state_indices,
1283
+ dtype=np.int32,
1284
+ ).tobytes()
1285
+ if not is_dummy and state_indices is not None
1286
+ else b""
1287
+ ),
1491
1288
  str(self.required_dst_info_num).encode("ascii"),
1492
1289
  ]
1493
1290
  )
@@ -1551,154 +1348,5 @@ class MooncakeKVReceiver(BaseKVReceiver):
1551
1348
  self.conclude_state = KVPoll.Failed
1552
1349
 
1553
1350
 
1554
- class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1555
- def __init__(self, host: str, port: int):
1556
- self.host = host
1557
- self.port = port
1558
- self.app = web.Application()
1559
- self.store = dict()
1560
- self.lock = asyncio.Lock()
1561
- self._setup_routes()
1562
- self.pp_size = None
1563
- self.attn_tp_size = None
1564
- self.dp_size = None
1565
- self.prefill_port_table: Dict[
1566
- int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
1567
- ] = {}
1568
-
1569
- # Start bootstrap server
1570
- self.thread = threading.Thread(target=self._run_server, daemon=True)
1571
- self.run()
1572
-
1573
- def run(self):
1574
- self.thread.start()
1575
-
1576
- def _setup_routes(self):
1577
- self.app.router.add_route("*", "/route", self._handle_route)
1578
- self.app.router.add_get("/health", self._handle_health_check)
1579
-
1580
- async def _handle_health_check(self, request):
1581
- return web.Response(text="OK", status=200)
1582
-
1583
- async def _handle_route(self, request: web.Request):
1584
- method = request.method
1585
- if method == "PUT":
1586
- return await self._handle_route_put(request)
1587
- elif method == "GET":
1588
- return await self._handle_route_get(request)
1589
- else:
1590
- return web.Response(
1591
- text="Method not allowed", status=405, content_type="application/json"
1592
- )
1593
-
1594
- async def _handle_route_put(self, request: web.Request):
1595
- data = await request.json()
1596
- role = data["role"]
1597
- attn_tp_size = data["attn_tp_size"]
1598
- attn_tp_rank = data["attn_tp_rank"]
1599
- attn_dp_size = data["attn_dp_size"]
1600
- attn_dp_rank = data["attn_dp_rank"]
1601
- pp_size = data["pp_size"]
1602
- pp_rank = data["pp_rank"]
1603
- system_dp_size = data["system_dp_size"]
1604
- system_dp_rank = data["system_dp_rank"]
1605
- rank_ip = data["rank_ip"]
1606
- rank_port = int(data["rank_port"])
1607
-
1608
- if self.attn_tp_size is None:
1609
- self.attn_tp_size = attn_tp_size
1610
-
1611
- if self.dp_size is None:
1612
- self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
1613
-
1614
- if self.pp_size is None:
1615
- self.pp_size = pp_size
1616
-
1617
- if role == "Prefill":
1618
- if system_dp_size == 1:
1619
- dp_group = attn_dp_rank
1620
- else:
1621
- dp_group = system_dp_rank
1622
-
1623
- # Add lock to make sure thread-safe
1624
- async with self.lock:
1625
- if dp_group not in self.prefill_port_table:
1626
- self.prefill_port_table[dp_group] = {}
1627
- if attn_tp_rank not in self.prefill_port_table[dp_group]:
1628
- self.prefill_port_table[dp_group][attn_tp_rank] = {}
1629
-
1630
- self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
1631
- "rank_ip": rank_ip,
1632
- "rank_port": rank_port,
1633
- }
1634
- logger.debug(
1635
- f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1636
- )
1637
-
1638
- return web.Response(text="OK", status=200)
1639
-
1640
- async def _handle_route_get(self, request: web.Request):
1641
- engine_rank = request.query.get("engine_rank")
1642
- target_dp_group = request.query.get("target_dp_group")
1643
- target_pp_rank = request.query.get("target_pp_rank")
1644
- if not engine_rank or not target_dp_group or not target_pp_rank:
1645
- return web.Response(text="Missing inputs for bootstrap server.", status=400)
1646
-
1647
- # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
1648
- if (
1649
- int(engine_rank) == -1
1650
- and int(target_dp_group) == -1
1651
- and int(target_pp_rank) == -1
1652
- ):
1653
- prefill_parallel_info = {
1654
- "prefill_attn_tp_size": self.attn_tp_size,
1655
- "prefill_dp_size": self.dp_size,
1656
- "prefill_pp_size": self.pp_size,
1657
- }
1658
- return web.json_response(prefill_parallel_info, status=200)
1659
-
1660
- # Find corresponding prefill info
1661
- async with self.lock:
1662
- bootstrap_info = self.prefill_port_table[int(target_dp_group)][
1663
- int(engine_rank)
1664
- ][int(target_pp_rank)]
1665
-
1666
- if bootstrap_info is not None:
1667
- return web.json_response(bootstrap_info, status=200)
1668
- else:
1669
- return web.Response(text="Bootstrap info not Found", status=404)
1670
-
1671
- def _run_server(self):
1672
- try:
1673
- # Event Loop
1674
- self._loop = asyncio.new_event_loop()
1675
- asyncio.set_event_loop(self._loop)
1676
-
1677
- access_log = None
1678
- if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
1679
- access_log = self.app.logger
1680
-
1681
- self._runner = web.AppRunner(self.app, access_log=access_log)
1682
- self._loop.run_until_complete(self._runner.setup())
1683
-
1684
- site = web.TCPSite(self._runner, host=self.host, port=self.port)
1685
- self._loop.run_until_complete(site.start())
1686
- self._loop.run_forever()
1687
- except Exception as e:
1688
- logger.error(f"Server error: {str(e)}")
1689
- finally:
1690
- # Cleanup
1691
- self._loop.run_until_complete(self._runner.cleanup())
1692
- self._loop.close()
1693
-
1694
- def close(self):
1695
- """Shutdown"""
1696
- if self._loop is not None and self._loop.is_running():
1697
- self._loop.call_soon_threadsafe(self._loop.stop)
1698
- logger.info("Stopping server loop...")
1699
-
1700
- if self.thread.is_alive():
1701
- self.thread.join(timeout=2)
1702
- logger.info("Server thread stopped")
1703
-
1704
- def poll(self) -> KVPoll: ...
1351
+ class MooncakeKVBootstrapServer(CommonKVBootstrapServer):
1352
+ pass