sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396) hide show
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -1,37 +1,30 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import dataclasses
5
4
  import logging
6
- import queue
7
- import socket
5
+ import os
8
6
  import struct
9
7
  import threading
8
+ import time
10
9
  import uuid
11
10
  from collections import defaultdict
12
- from functools import cache
13
- from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
11
+ from typing import Dict, List, Optional, Set
14
12
 
15
13
  import numpy as np
16
14
  import numpy.typing as npt
17
15
  import requests
18
- import zmq
19
- from aiohttp import web
20
16
 
21
- from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
17
+ from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
22
18
  from sglang.srt.disaggregation.common.conn import (
23
19
  CommonKVBootstrapServer,
24
20
  CommonKVManager,
25
21
  CommonKVReceiver,
22
+ CommonKVSender,
26
23
  )
27
24
  from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
25
  from sglang.srt.disaggregation.utils import DisaggregationMode
29
26
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.utils import (
31
- format_tcp_address,
32
- get_local_ip_auto,
33
- is_valid_ipv6_address,
34
- )
27
+ from sglang.srt.utils import get_int_env_var
35
28
 
36
29
  logger = logging.getLogger(__name__)
37
30
 
@@ -78,6 +71,9 @@ class KVArgsRegisterInfo:
78
71
  dst_kv_ptrs: list[int]
79
72
  dst_aux_ptrs: list[int]
80
73
  gpu_id: int
74
+ decode_tp_size: int
75
+ decode_tp_rank: int
76
+ dst_kv_item_len: int
81
77
 
82
78
  @classmethod
83
79
  def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +86,9 @@ class KVArgsRegisterInfo:
90
86
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
91
87
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
92
88
  gpu_id=int(msg[7].decode("ascii")),
89
+ decode_tp_size=int(msg[8].decode("ascii")),
90
+ decode_tp_rank=int(msg[9].decode("ascii")),
91
+ dst_kv_item_len=int(msg[10].decode("ascii")),
93
92
  )
94
93
 
95
94
 
@@ -107,8 +106,14 @@ class TransferStatus:
107
106
  def is_done(self):
108
107
  if self.num_kvs_expected is None:
109
108
  return False
109
+ # Check for failure state
110
+ if self.num_kvs_expected == -1:
111
+ return True # Failed transfers are considered "done"
110
112
  return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
111
113
 
114
+ def is_failed(self):
115
+ return self.num_kvs_expected == -1
116
+
112
117
 
113
118
  class NixlKVManager(CommonKVManager):
114
119
  def __init__(
@@ -128,26 +133,133 @@ class NixlKVManager(CommonKVManager):
128
133
  "to run SGLang with NixlTransferEngine."
129
134
  ) from e
130
135
  self.agent = nixl_agent(str(uuid.uuid4()))
131
- self.local_ip = get_local_ip_auto()
132
- self.server_socket = zmq.Context().socket(zmq.PULL)
133
- if is_valid_ipv6_address(self.local_ip):
134
- self.server_socket.setsockopt(zmq.IPV6, 1)
135
136
  self.register_buffer_to_engine()
136
137
 
137
138
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
138
- self.request_status: Dict[int, KVPoll] = {}
139
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
140
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
141
139
  self._start_bootstrap_thread()
142
140
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
143
141
  self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
144
142
  TransferStatus
145
143
  )
144
+ self.heartbeat_failures = {}
145
+ self.session_pool = defaultdict(requests.Session)
146
+ self.session_pool_lock = threading.Lock()
147
+ self.addr_to_rooms_tracker = defaultdict(set)
148
+ self.connection_lock = threading.Lock()
149
+
150
+ # Heartbeat interval should be at least 2 seconds
151
+ self.heartbeat_interval = max(
152
+ float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
153
+ )
154
+ # Heartbeat failure should be at least 1
155
+ self.max_failures = max(
156
+ get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
157
+ )
158
+ self._start_heartbeat_checker_thread()
146
159
  else:
147
160
  raise ValueError(
148
161
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
149
162
  )
150
163
 
164
+ def _start_heartbeat_checker_thread(self):
165
+ """
166
+ Start the heartbeat checker thread for Decode worker.
167
+ TODO (smor): unite nixl heartbeat checker with mooncake's.
168
+ """
169
+
170
+ def heartbeat_checker():
171
+ while True:
172
+ time.sleep(self.heartbeat_interval)
173
+ with self.connection_lock:
174
+ addresses = list(self.prefill_dp_size_table.keys())
175
+
176
+ for bootstrap_addr in addresses:
177
+ session = None
178
+ try:
179
+ with self.session_pool_lock:
180
+ session = self.session_pool[bootstrap_addr]
181
+ response = session.get(
182
+ f"http://{bootstrap_addr}/health",
183
+ timeout=(2, 3),
184
+ headers={"Connection": "keep-alive"},
185
+ )
186
+ if response.status_code == 200:
187
+ self.heartbeat_failures[bootstrap_addr] = 0
188
+
189
+ current_rooms = self.addr_to_rooms_tracker[
190
+ bootstrap_addr
191
+ ].copy()
192
+
193
+ for bootstrap_room in current_rooms:
194
+ # Remove successful transfers from the tracker
195
+ if bootstrap_room not in self.transfer_statuses:
196
+ self.addr_to_rooms_tracker[bootstrap_addr].discard(
197
+ bootstrap_room
198
+ )
199
+ else:
200
+ logger.info(
201
+ f"Attempting to reconnect to {bootstrap_addr}..."
202
+ )
203
+ self.heartbeat_failures[bootstrap_addr] = (
204
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
205
+ )
206
+ with self.session_pool_lock:
207
+ if bootstrap_addr in self.session_pool:
208
+ del self.session_pool[bootstrap_addr]
209
+ except Exception:
210
+ logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
211
+ self.heartbeat_failures[bootstrap_addr] = (
212
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
213
+ )
214
+
215
+ if (
216
+ self.heartbeat_failures.get(bootstrap_addr, 0)
217
+ >= self.max_failures
218
+ ):
219
+ self._handle_node_failure(bootstrap_addr)
220
+ with self.session_pool_lock:
221
+ if bootstrap_addr in self.session_pool:
222
+ del self.session_pool[bootstrap_addr]
223
+
224
+ threading.Thread(target=heartbeat_checker, daemon=True).start()
225
+
226
+ def _handle_node_failure(self, failed_bootstrap_addr):
227
+ """Handle failure of a prefill node."""
228
+ with self.connection_lock:
229
+ keys_to_remove = [
230
+ k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
231
+ ]
232
+ for k in keys_to_remove:
233
+ del self.connection_pool[k]
234
+ if failed_bootstrap_addr in self.prefill_tp_size_table:
235
+ del self.prefill_tp_size_table[failed_bootstrap_addr]
236
+ if failed_bootstrap_addr in self.prefill_dp_size_table:
237
+ del self.prefill_dp_size_table[failed_bootstrap_addr]
238
+ if failed_bootstrap_addr in self.prefill_pp_size_table:
239
+ del self.prefill_pp_size_table[failed_bootstrap_addr]
240
+
241
+ possible_affected_rooms = self.addr_to_rooms_tracker.get(
242
+ failed_bootstrap_addr, []
243
+ )
244
+ if failed_bootstrap_addr in self.addr_to_rooms_tracker:
245
+ del self.addr_to_rooms_tracker[failed_bootstrap_addr]
246
+
247
+ # Mark all pending transfers associated with the failed node as failed
248
+ affected_rooms = []
249
+ for room in possible_affected_rooms:
250
+ if (
251
+ room in self.transfer_statuses
252
+ and not self.transfer_statuses[room].is_done()
253
+ ):
254
+ # Mark the transfer as failed by setting a special state
255
+ self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
256
+ affected_rooms.append(room)
257
+
258
+ logger.error(
259
+ f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
260
+ f"{len(affected_rooms)} transfers affected"
261
+ )
262
+
151
263
  def check_status(self, bootstrap_room: int):
152
264
  return self.request_status[bootstrap_room]
153
265
 
@@ -160,13 +272,16 @@ class NixlKVManager(CommonKVManager):
160
272
  self.request_status[bootstrap_room], status
161
273
  )
162
274
 
275
+ def record_failure(self, bootstrap_room: int, failure_reason: str):
276
+ pass
277
+
163
278
  def register_buffer_to_engine(self):
164
279
  kv_addrs = []
165
280
  for kv_data_ptr, kv_data_len in zip(
166
281
  self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
167
282
  ):
168
283
  kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
169
- self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
284
+ self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
170
285
  logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
171
286
  if not self.kv_descs:
172
287
  raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +290,7 @@ class NixlKVManager(CommonKVManager):
175
290
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
176
291
  ):
177
292
  aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
178
- self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
293
+ self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
179
294
  logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
180
295
  if not self.aux_descs:
181
296
  raise Exception("NIXL memory registration failed for aux tensors")
@@ -204,14 +319,44 @@ class NixlKVManager(CommonKVManager):
204
319
 
205
320
  logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
206
321
  # Make descs
207
- num_layers = len(self.kv_args.kv_data_ptrs)
322
+ if self.is_mla_backend:
323
+ src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
324
+ self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
325
+ )
326
+ kv_item_len = self.kv_args.kv_item_lens[0]
327
+ layers_params = [
328
+ (
329
+ src_kv_ptrs[layer_id],
330
+ dst_kv_ptrs[layer_id],
331
+ kv_item_len,
332
+ )
333
+ for layer_id in range(layers_current_pp_stage)
334
+ ]
335
+ else:
336
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
337
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
338
+ )
339
+
340
+ kv_item_len = self.kv_args.kv_item_lens[0]
341
+ layers_params = [
342
+ (
343
+ src_k_ptrs[layer_id],
344
+ dst_k_ptrs[layer_id],
345
+ kv_item_len,
346
+ )
347
+ for layer_id in range(layers_current_pp_stage)
348
+ ] + [
349
+ (
350
+ src_v_ptrs[layer_id],
351
+ dst_v_ptrs[layer_id],
352
+ kv_item_len,
353
+ )
354
+ for layer_id in range(layers_current_pp_stage)
355
+ ]
356
+
208
357
  src_addrs = []
209
358
  dst_addrs = []
210
- for layer_id in range(num_layers):
211
- src_ptr = self.kv_args.kv_data_ptrs[layer_id]
212
- dst_ptr = dst_kv_ptrs[layer_id]
213
- item_len = self.kv_args.kv_item_lens[layer_id]
214
-
359
+ for src_ptr, dst_ptr, item_len in layers_params:
215
360
  for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
216
361
  src_addr = src_ptr + int(prefill_index[0]) * item_len
217
362
  dst_addr = dst_ptr + int(decode_index[0]) * item_len
@@ -222,8 +367,8 @@ class NixlKVManager(CommonKVManager):
222
367
  logger.debug(
223
368
  f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
224
369
  )
225
- src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
226
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
370
+ src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
371
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
227
372
  # Transfer data
228
373
  xfer_handle = self.agent.initialize_xfer(
229
374
  "WRITE",
@@ -239,6 +384,137 @@ class NixlKVManager(CommonKVManager):
239
384
  raise Exception("KVSender failed to post transfer")
240
385
  return xfer_handle
241
386
 
387
+ def send_kvcache_slice(
388
+ self,
389
+ peer_name: str,
390
+ prefill_kv_indices: npt.NDArray[np.int32],
391
+ dst_kv_ptrs: list[int],
392
+ dst_kv_indices: npt.NDArray[np.int32],
393
+ dst_gpu_id: int,
394
+ notif: str,
395
+ prefill_tp_size: int,
396
+ decode_tp_size: int,
397
+ decode_tp_rank: int,
398
+ dst_kv_item_len: int,
399
+ ):
400
+ # Get configuration from kv_args
401
+ local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
402
+ dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
403
+ num_kv_heads = self.kv_args.kv_head_num
404
+
405
+ # Calculate head distribution
406
+ src_heads_per_rank = num_kv_heads
407
+ dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
408
+
409
+ src_kv_item_len = self.kv_args.kv_item_lens[0]
410
+ page_size = self.kv_args.page_size
411
+
412
+ bytes_per_head_slice_to_send = (
413
+ dst_kv_item_len // page_size // dst_heads_per_rank
414
+ )
415
+
416
+ # Determine which heads to send
417
+ if prefill_tp_size > decode_tp_size:
418
+ # Multiple prefill ranks to one decode rank
419
+ src_head_start_offset = 0
420
+ num_heads_to_send = src_heads_per_rank
421
+ dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
422
+ else:
423
+ # Send KVCache from 1 prefill instance to multiple decode instances
424
+ src_head_start_offset = (
425
+ dst_tp_rank_in_group * dst_heads_per_rank
426
+ ) % src_heads_per_rank
427
+ num_heads_to_send = dst_heads_per_rank
428
+ dst_head_start_offset = 0
429
+
430
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
431
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
432
+ )
433
+ # Create transfer descriptors
434
+ src_addrs = []
435
+ dst_addrs = []
436
+
437
+ bytes_per_token_on_prefill = src_kv_item_len // page_size
438
+ bytes_per_token_on_decode = dst_kv_item_len // page_size
439
+
440
+ # Calculate precise byte offset and length for the sub-slice within the token
441
+ src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
442
+ dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
443
+ heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
444
+
445
+ src_dst_ptr_pairs = [
446
+ (
447
+ src_k_ptrs[layer_id],
448
+ dst_k_ptrs[layer_id],
449
+ )
450
+ for layer_id in range(layers_current_pp_stage)
451
+ ] + [
452
+ (
453
+ src_v_ptrs[layer_id],
454
+ dst_v_ptrs[layer_id],
455
+ )
456
+ for layer_id in range(layers_current_pp_stage)
457
+ ]
458
+
459
+ src_addrs = []
460
+ dst_addrs = []
461
+
462
+ # Calculate strides for a single token slot
463
+ bytes_per_token_on_prefill = src_kv_item_len // page_size
464
+ bytes_per_token_on_decode = dst_kv_item_len // page_size
465
+
466
+ for src_ptr, dst_ptr in src_dst_ptr_pairs:
467
+ for i in range(len(prefill_kv_indices)):
468
+ prefill_page_idx = int(prefill_kv_indices[i])
469
+ decode_page_idx = int(dst_kv_indices[i])
470
+
471
+ # Get the starting addresses for the current src and dst pages
472
+ src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
473
+ dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
474
+
475
+ # Iterate through each valid token slot within the current page
476
+ for token_slot_in_page in range(page_size):
477
+ # Calculate the start address of the current token slot
478
+ src_token_slot_start_addr = (
479
+ src_page_start_addr
480
+ + token_slot_in_page * bytes_per_token_on_prefill
481
+ )
482
+ dst_token_slot_start_addr = (
483
+ dst_page_start_addr
484
+ + token_slot_in_page * bytes_per_token_on_decode
485
+ )
486
+
487
+ # Calculate final src and dst addresses by applying head-slice offsets
488
+ src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
489
+ dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
490
+
491
+ src_addrs.append(
492
+ (
493
+ src_slice_addr,
494
+ heads_bytes_per_token_to_send,
495
+ self.kv_args.gpu_id,
496
+ )
497
+ )
498
+ dst_addrs.append(
499
+ (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
500
+ )
501
+
502
+ # Use NIXL agent for transfer
503
+ src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
504
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
505
+
506
+ xfer_handle = self.agent.initialize_xfer(
507
+ "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
508
+ )
509
+ if not xfer_handle:
510
+ raise Exception("Failed to create sliced KV transfer")
511
+
512
+ state = self.agent.transfer(xfer_handle)
513
+ if state == "ERR":
514
+ raise Exception("Failed to post sliced KV transfer")
515
+
516
+ return xfer_handle
517
+
242
518
  def send_aux(
243
519
  self,
244
520
  peer_name: str,
@@ -247,16 +523,21 @@ class NixlKVManager(CommonKVManager):
247
523
  dst_aux_index: int,
248
524
  notif: str,
249
525
  ):
250
- # Make descs
251
- aux_item_len = self.kv_args.aux_item_lens[0]
252
- prefill_aux_addr = (
253
- self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
254
- )
255
- decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
256
- src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
257
- dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
258
- src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
259
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
526
+ src_addrs = []
527
+ dst_addrs = []
528
+
529
+ prefill_aux_ptrs = self.kv_args.aux_data_ptrs
530
+ prefill_aux_item_lens = self.kv_args.aux_item_lens
531
+
532
+ for i, _ in enumerate(dst_aux_ptrs):
533
+ length = prefill_aux_item_lens[i]
534
+ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
535
+ dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
536
+ src_addrs.append((src_addr, length, 0))
537
+ dst_addrs.append((dst_addr, length, 0))
538
+
539
+ src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
540
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
260
541
  # Transfer data
261
542
  xfer_handle = self.agent.initialize_xfer(
262
543
  "WRITE",
@@ -296,17 +577,38 @@ class NixlKVManager(CommonKVManager):
296
577
  assert req.agent_name in self.decode_kv_args_table
297
578
 
298
579
  notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
299
- kv_xfer_handle = self.send_kvcache(
300
- req.agent_name,
301
- kv_indices,
302
- self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
303
- chunked_dst_kv_indice,
304
- self.decode_kv_args_table[req.agent_name].gpu_id,
305
- notif,
306
- )
580
+ decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
581
+
582
+ if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
583
+ kv_xfer_handle = self.send_kvcache(
584
+ req.agent_name,
585
+ kv_indices,
586
+ self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
587
+ chunked_dst_kv_indice,
588
+ self.decode_kv_args_table[req.agent_name].gpu_id,
589
+ notif,
590
+ )
591
+ else:
592
+ kv_xfer_handle = self.send_kvcache_slice(
593
+ req.agent_name,
594
+ kv_indices,
595
+ self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
596
+ chunked_dst_kv_indice,
597
+ self.decode_kv_args_table[req.agent_name].gpu_id,
598
+ notif,
599
+ prefill_tp_size=self.attn_tp_size,
600
+ decode_tp_size=decode_tp_size,
601
+ decode_tp_rank=self.decode_kv_args_table[
602
+ req.agent_name
603
+ ].decode_tp_rank,
604
+ dst_kv_item_len=self.decode_kv_args_table[
605
+ req.agent_name
606
+ ].dst_kv_item_len,
607
+ )
608
+
307
609
  handles.append(kv_xfer_handle)
308
610
  # Only the last chunk we need to send the aux data.
309
- if is_last:
611
+ if is_last and self.pp_group.is_last_rank:
310
612
  assert aux_index is not None
311
613
  aux_xfer_handle = self.send_aux(
312
614
  req.agent_name,
@@ -344,9 +646,6 @@ class NixlKVManager(CommonKVManager):
344
646
  return False
345
647
  return self.transfer_statuses[room].is_done()
346
648
 
347
- def _bind_server_socket(self):
348
- self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
349
-
350
649
  def _start_bootstrap_thread(self):
351
650
  self._bind_server_socket()
352
651
 
@@ -387,7 +686,7 @@ class NixlKVManager(CommonKVManager):
387
686
  threading.Thread(target=bootstrap_thread).start()
388
687
 
389
688
 
390
- class NixlKVSender(BaseKVSender):
689
+ class NixlKVSender(CommonKVSender):
391
690
 
392
691
  def __init__(
393
692
  self,
@@ -397,20 +696,10 @@ class NixlKVSender(BaseKVSender):
397
696
  dest_tp_ranks: List[int],
398
697
  pp_rank: int,
399
698
  ):
400
- self.kv_mgr = mgr
401
- self.bootstrap_room = bootstrap_room
402
- self.aux_index = None
403
- self.bootstrap_server_url = bootstrap_addr
699
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
404
700
  self.xfer_handles = []
405
701
  self.has_sent = False
406
702
  self.chunk_id = 0
407
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
408
- # inner state
409
- self.curr_idx = 0
410
-
411
- def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
412
- self.num_kv_indices = num_kv_indices
413
- self.aux_index = aux_index
414
703
 
415
704
  def send(
416
705
  self,
@@ -454,11 +743,17 @@ class NixlKVReceiver(CommonKVReceiver):
454
743
  mgr: NixlKVManager,
455
744
  bootstrap_addr: str,
456
745
  bootstrap_room: Optional[int] = None,
457
- data_parallel_rank: Optional[int] = None,
746
+ prefill_dp_rank: Optional[int] = None,
458
747
  ):
459
748
  self.started_transfer = False
460
749
  self.conclude_state = None
461
- super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
750
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
751
+
752
+ # Track this room with its bootstrap address for heartbeat monitoring
753
+ if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
754
+ self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
755
+ self.bootstrap_room
756
+ )
462
757
 
463
758
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
464
759
  for bootstrap_info in self.bootstrap_infos:
@@ -494,9 +789,16 @@ class NixlKVReceiver(CommonKVReceiver):
494
789
 
495
790
  self.kv_mgr.update_transfer_status()
496
791
  if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
497
- self.conclude_state = KVPoll.Success
792
+ # Check if the transfer failed
793
+ if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
794
+ self.conclude_state = KVPoll.Failed
795
+ logger.error(
796
+ f"Transfer for room {self.bootstrap_room} failed due to node failure"
797
+ )
798
+ else:
799
+ self.conclude_state = KVPoll.Success
498
800
  del self.kv_mgr.transfer_statuses[self.bootstrap_room]
499
- return KVPoll.Success # type: ignore
801
+ return self.conclude_state # type: ignore
500
802
  return KVPoll.WaitingForInput # type: ignore
501
803
 
502
804
  def _register_kv_args(self):
@@ -521,6 +823,9 @@ class NixlKVReceiver(CommonKVReceiver):
521
823
  packed_kv_data_ptrs,
522
824
  packed_aux_data_ptrs,
523
825
  str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
826
+ str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
827
+ str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
828
+ str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
524
829
  ]
525
830
  )
526
831