sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,37 +1,30 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import dataclasses
5
4
  import logging
6
- import queue
7
- import socket
5
+ import os
8
6
  import struct
9
7
  import threading
8
+ import time
10
9
  import uuid
11
10
  from collections import defaultdict
12
- from functools import cache
13
- from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
11
+ from typing import Dict, List, Optional, Set
14
12
 
15
13
  import numpy as np
16
14
  import numpy.typing as npt
17
15
  import requests
18
- import zmq
19
- from aiohttp import web
20
16
 
21
- from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
17
+ from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
22
18
  from sglang.srt.disaggregation.common.conn import (
23
19
  CommonKVBootstrapServer,
24
20
  CommonKVManager,
25
21
  CommonKVReceiver,
22
+ CommonKVSender,
26
23
  )
27
24
  from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
25
  from sglang.srt.disaggregation.utils import DisaggregationMode
29
26
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.utils import (
31
- format_tcp_address,
32
- get_local_ip_auto,
33
- is_valid_ipv6_address,
34
- )
27
+ from sglang.srt.utils import get_int_env_var
35
28
 
36
29
  logger = logging.getLogger(__name__)
37
30
 
@@ -78,6 +71,9 @@ class KVArgsRegisterInfo:
78
71
  dst_kv_ptrs: list[int]
79
72
  dst_aux_ptrs: list[int]
80
73
  gpu_id: int
74
+ decode_tp_size: int
75
+ decode_tp_rank: int
76
+ dst_kv_item_len: int
81
77
 
82
78
  @classmethod
83
79
  def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +86,9 @@ class KVArgsRegisterInfo:
90
86
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
91
87
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
92
88
  gpu_id=int(msg[7].decode("ascii")),
89
+ decode_tp_size=int(msg[8].decode("ascii")),
90
+ decode_tp_rank=int(msg[9].decode("ascii")),
91
+ dst_kv_item_len=int(msg[10].decode("ascii")),
93
92
  )
94
93
 
95
94
 
@@ -107,8 +106,14 @@ class TransferStatus:
107
106
  def is_done(self):
108
107
  if self.num_kvs_expected is None:
109
108
  return False
109
+ # Check for failure state
110
+ if self.num_kvs_expected == -1:
111
+ return True # Failed transfers are considered "done"
110
112
  return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
111
113
 
114
+ def is_failed(self):
115
+ return self.num_kvs_expected == -1
116
+
112
117
 
113
118
  class NixlKVManager(CommonKVManager):
114
119
  def __init__(
@@ -128,26 +133,133 @@ class NixlKVManager(CommonKVManager):
128
133
  "to run SGLang with NixlTransferEngine."
129
134
  ) from e
130
135
  self.agent = nixl_agent(str(uuid.uuid4()))
131
- self.local_ip = get_local_ip_auto()
132
- self.server_socket = zmq.Context().socket(zmq.PULL)
133
- if is_valid_ipv6_address(self.local_ip):
134
- self.server_socket.setsockopt(zmq.IPV6, 1)
135
136
  self.register_buffer_to_engine()
136
137
 
137
138
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
138
- self.request_status: Dict[int, KVPoll] = {}
139
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
140
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
141
139
  self._start_bootstrap_thread()
142
140
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
143
141
  self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
144
142
  TransferStatus
145
143
  )
144
+ self.heartbeat_failures = {}
145
+ self.session_pool = defaultdict(requests.Session)
146
+ self.session_pool_lock = threading.Lock()
147
+ self.addr_to_rooms_tracker = defaultdict(set)
148
+ self.connection_lock = threading.Lock()
149
+
150
+ # Heartbeat interval should be at least 2 seconds
151
+ self.heartbeat_interval = max(
152
+ float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
153
+ )
154
+ # Heartbeat failure should be at least 1
155
+ self.max_failures = max(
156
+ get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
157
+ )
158
+ self._start_heartbeat_checker_thread()
146
159
  else:
147
160
  raise ValueError(
148
161
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
149
162
  )
150
163
 
164
+ def _start_heartbeat_checker_thread(self):
165
+ """
166
+ Start the heartbeat checker thread for Decode worker.
167
+ TODO (smor): unite nixl heartbeat checker with mooncake's.
168
+ """
169
+
170
+ def heartbeat_checker():
171
+ while True:
172
+ time.sleep(self.heartbeat_interval)
173
+ with self.connection_lock:
174
+ addresses = list(self.prefill_dp_size_table.keys())
175
+
176
+ for bootstrap_addr in addresses:
177
+ session = None
178
+ try:
179
+ with self.session_pool_lock:
180
+ session = self.session_pool[bootstrap_addr]
181
+ response = session.get(
182
+ f"http://{bootstrap_addr}/health",
183
+ timeout=(2, 3),
184
+ headers={"Connection": "keep-alive"},
185
+ )
186
+ if response.status_code == 200:
187
+ self.heartbeat_failures[bootstrap_addr] = 0
188
+
189
+ current_rooms = self.addr_to_rooms_tracker[
190
+ bootstrap_addr
191
+ ].copy()
192
+
193
+ for bootstrap_room in current_rooms:
194
+ # Remove successful transfers from the tracker
195
+ if bootstrap_room not in self.transfer_statuses:
196
+ self.addr_to_rooms_tracker[bootstrap_addr].discard(
197
+ bootstrap_room
198
+ )
199
+ else:
200
+ logger.info(
201
+ f"Attempting to reconnect to {bootstrap_addr}..."
202
+ )
203
+ self.heartbeat_failures[bootstrap_addr] = (
204
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
205
+ )
206
+ with self.session_pool_lock:
207
+ if bootstrap_addr in self.session_pool:
208
+ del self.session_pool[bootstrap_addr]
209
+ except Exception:
210
+ logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
211
+ self.heartbeat_failures[bootstrap_addr] = (
212
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
213
+ )
214
+
215
+ if (
216
+ self.heartbeat_failures.get(bootstrap_addr, 0)
217
+ >= self.max_failures
218
+ ):
219
+ self._handle_node_failure(bootstrap_addr)
220
+ with self.session_pool_lock:
221
+ if bootstrap_addr in self.session_pool:
222
+ del self.session_pool[bootstrap_addr]
223
+
224
+ threading.Thread(target=heartbeat_checker, daemon=True).start()
225
+
226
+ def _handle_node_failure(self, failed_bootstrap_addr):
227
+ """Handle failure of a prefill node."""
228
+ with self.connection_lock:
229
+ keys_to_remove = [
230
+ k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
231
+ ]
232
+ for k in keys_to_remove:
233
+ del self.connection_pool[k]
234
+ if failed_bootstrap_addr in self.prefill_tp_size_table:
235
+ del self.prefill_tp_size_table[failed_bootstrap_addr]
236
+ if failed_bootstrap_addr in self.prefill_dp_size_table:
237
+ del self.prefill_dp_size_table[failed_bootstrap_addr]
238
+ if failed_bootstrap_addr in self.prefill_pp_size_table:
239
+ del self.prefill_pp_size_table[failed_bootstrap_addr]
240
+
241
+ possible_affected_rooms = self.addr_to_rooms_tracker.get(
242
+ failed_bootstrap_addr, []
243
+ )
244
+ if failed_bootstrap_addr in self.addr_to_rooms_tracker:
245
+ del self.addr_to_rooms_tracker[failed_bootstrap_addr]
246
+
247
+ # Mark all pending transfers associated with the failed node as failed
248
+ affected_rooms = []
249
+ for room in possible_affected_rooms:
250
+ if (
251
+ room in self.transfer_statuses
252
+ and not self.transfer_statuses[room].is_done()
253
+ ):
254
+ # Mark the transfer as failed by setting a special state
255
+ self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
256
+ affected_rooms.append(room)
257
+
258
+ logger.error(
259
+ f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
260
+ f"{len(affected_rooms)} transfers affected"
261
+ )
262
+
151
263
  def check_status(self, bootstrap_room: int):
152
264
  return self.request_status[bootstrap_room]
153
265
 
@@ -160,13 +272,16 @@ class NixlKVManager(CommonKVManager):
160
272
  self.request_status[bootstrap_room], status
161
273
  )
162
274
 
275
+ def record_failure(self, bootstrap_room: int, failure_reason: str):
276
+ pass
277
+
163
278
  def register_buffer_to_engine(self):
164
279
  kv_addrs = []
165
280
  for kv_data_ptr, kv_data_len in zip(
166
281
  self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
167
282
  ):
168
283
  kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
169
- self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
284
+ self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
170
285
  logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
171
286
  if not self.kv_descs:
172
287
  raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +290,7 @@ class NixlKVManager(CommonKVManager):
175
290
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
176
291
  ):
177
292
  aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
178
- self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
293
+ self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
179
294
  logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
180
295
  if not self.aux_descs:
181
296
  raise Exception("NIXL memory registration failed for aux tensors")
@@ -222,8 +337,8 @@ class NixlKVManager(CommonKVManager):
222
337
  logger.debug(
223
338
  f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
224
339
  )
225
- src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
226
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
340
+ src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
341
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
227
342
  # Transfer data
228
343
  xfer_handle = self.agent.initialize_xfer(
229
344
  "WRITE",
@@ -239,6 +354,140 @@ class NixlKVManager(CommonKVManager):
239
354
  raise Exception("KVSender failed to post transfer")
240
355
  return xfer_handle
241
356
 
357
def send_kvcache_slice(
    self,
    peer_name: str,
    prefill_kv_indices: npt.NDArray[np.int32],
    dst_kv_ptrs: list[int],
    dst_kv_indices: npt.NDArray[np.int32],
    dst_gpu_id: int,
    notif: str,
    prefill_tp_size: int,
    decode_tp_size: int,
    decode_tp_rank: int,
    dst_kv_item_len: int,
):
    """Post a NIXL WRITE that copies this rank's share of KV heads to a decode peer.

    Intended for the non-MLA case where prefill and decode use different
    attention-TP sizes. A token's KV entry then holds a different number of
    heads on each side, so only a contiguous head slice of every token slot
    is transferred. One descriptor pair is built for every (K or V) layer,
    every page pair, and every token slot in the page.

    Args:
        peer_name: Remote NIXL agent name passed to ``initialize_xfer``.
        prefill_kv_indices: Page indices into the local KV buffers.
        dst_kv_ptrs: Remote per-layer base pointers. Assumed to be laid out
            like ``kv_args.kv_data_ptrs``: all K layers, then all V layers.
        dst_kv_indices: Remote page indices, paired by position with
            ``prefill_kv_indices`` (must be at least as long).
        dst_gpu_id: Device id recorded in the remote descriptors.
        notif: Notification string, ASCII-encoded and attached to the transfer.
        prefill_tp_size: Attention TP size on the prefill (local) side.
        decode_tp_size: Attention TP size on the decode (remote) side.
        decode_tp_rank: The decode peer's rank; reduced modulo ``decode_tp_size``.
        dst_kv_item_len: Bytes per page per layer on the decode side.

    Returns:
        The handle returned by ``self.agent.initialize_xfer``.

    Raises:
        IndexError: If ``dst_kv_indices`` is shorter than ``prefill_kv_indices``.
        Exception: If the transfer cannot be created, or posting returns "ERR".
    """
    local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
    dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
    page_size = self.kv_args.page_size
    src_kv_item_len = self.kv_args.kv_item_lens[0]

    # Each local rank holds kv_head_num heads. The decode side holds that
    # count scaled by prefill_tp_size / decode_tp_size.
    src_heads_per_rank = self.kv_args.kv_head_num
    dst_heads_per_rank = src_heads_per_rank * prefill_tp_size // decode_tp_size
    bytes_per_head = dst_kv_item_len // page_size // dst_heads_per_rank

    if prefill_tp_size > decode_tp_size:
        # Multiple prefill ranks fill disjoint head ranges of one decode rank.
        # The destination head offset must stay inside the decode token slot
        # (< dst_heads_per_rank). Without the modulo, ranks beyond the first
        # group would write into the next token's slot.
        # NOTE(review): assumes consecutive prefill ranks are grouped onto one
        # decode rank — confirm against the receiver's target-rank selection.
        src_head_start = 0
        num_heads_to_send = src_heads_per_rank
        dst_head_start = (
            local_tp_rank_in_group * src_heads_per_rank
        ) % dst_heads_per_rank
    else:
        # One prefill rank feeds several decode ranks; pick the sub-range of
        # local heads that belongs to this decode rank.
        src_head_start = (
            dst_tp_rank_in_group * dst_heads_per_rank
        ) % src_heads_per_rank
        num_heads_to_send = dst_heads_per_rank
        dst_head_start = 0

    # Byte offsets and length of the head slice inside one token slot.
    src_head_offset = src_head_start * bytes_per_head
    dst_head_offset = dst_head_start * bytes_per_head
    slice_len = num_heads_to_send * bytes_per_head

    # Byte stride between consecutive token slots in a page.
    bytes_per_token_on_prefill = src_kv_item_len // page_size
    bytes_per_token_on_decode = dst_kv_item_len // page_size

    # The first half of kv_data_ptrs are K layers, the second half V layers.
    num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
    src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
    src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
    dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
    dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
    src_dst_ptr_pairs = [
        (src_k_ptrs[layer_id], dst_k_ptrs[layer_id])
        for layer_id in range(len(src_k_ptrs))
    ] + [
        (src_v_ptrs[layer_id], dst_v_ptrs[layer_id])
        for layer_id in range(len(src_v_ptrs))
    ]

    # Convert the page indices once rather than once per layer. Indexing
    # dst_kv_indices by position keeps the IndexError on a short destination.
    page_pairs = [
        (int(prefill_kv_indices[i]), int(dst_kv_indices[i]))
        for i in range(len(prefill_kv_indices))
    ]

    src_gpu_id = self.kv_args.gpu_id
    src_addrs = []
    dst_addrs = []
    for src_ptr, dst_ptr in src_dst_ptr_pairs:
        for prefill_page_idx, decode_page_idx in page_pairs:
            # Start of the head slice in token slot 0 of each page.
            src_base = src_ptr + prefill_page_idx * src_kv_item_len + src_head_offset
            dst_base = dst_ptr + decode_page_idx * dst_kv_item_len + dst_head_offset
            for token_slot in range(page_size):
                src_addrs.append(
                    (
                        src_base + token_slot * bytes_per_token_on_prefill,
                        slice_len,
                        src_gpu_id,
                    )
                )
                dst_addrs.append(
                    (
                        dst_base + token_slot * bytes_per_token_on_decode,
                        slice_len,
                        dst_gpu_id,
                    )
                )

    src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
    dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")

    xfer_handle = self.agent.initialize_xfer(
        "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
    )
    if not xfer_handle:
        raise Exception("Failed to create sliced KV transfer")

    state = self.agent.transfer(xfer_handle)
    if state == "ERR":
        raise Exception("Failed to post sliced KV transfer")

    return xfer_handle
490
+
242
491
  def send_aux(
243
492
  self,
244
493
  peer_name: str,
@@ -255,8 +504,8 @@ class NixlKVManager(CommonKVManager):
255
504
  decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
256
505
  src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
257
506
  dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
258
- src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
259
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
507
+ src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
508
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
260
509
  # Transfer data
261
510
  xfer_handle = self.agent.initialize_xfer(
262
511
  "WRITE",
@@ -296,14 +545,35 @@ class NixlKVManager(CommonKVManager):
296
545
  assert req.agent_name in self.decode_kv_args_table
297
546
 
298
547
  notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
299
- kv_xfer_handle = self.send_kvcache(
300
- req.agent_name,
301
- kv_indices,
302
- self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
303
- chunked_dst_kv_indice,
304
- self.decode_kv_args_table[req.agent_name].gpu_id,
305
- notif,
306
- )
548
+ decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
549
+
550
+ if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
551
+ kv_xfer_handle = self.send_kvcache(
552
+ req.agent_name,
553
+ kv_indices,
554
+ self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
555
+ chunked_dst_kv_indice,
556
+ self.decode_kv_args_table[req.agent_name].gpu_id,
557
+ notif,
558
+ )
559
+ else:
560
+ kv_xfer_handle = self.send_kvcache_slice(
561
+ req.agent_name,
562
+ kv_indices,
563
+ self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
564
+ chunked_dst_kv_indice,
565
+ self.decode_kv_args_table[req.agent_name].gpu_id,
566
+ notif,
567
+ prefill_tp_size=self.attn_tp_size,
568
+ decode_tp_size=decode_tp_size,
569
+ decode_tp_rank=self.decode_kv_args_table[
570
+ req.agent_name
571
+ ].decode_tp_rank,
572
+ dst_kv_item_len=self.decode_kv_args_table[
573
+ req.agent_name
574
+ ].dst_kv_item_len,
575
+ )
576
+
307
577
  handles.append(kv_xfer_handle)
308
578
  # Only the last chunk we need to send the aux data.
309
579
  if is_last:
@@ -344,9 +614,6 @@ class NixlKVManager(CommonKVManager):
344
614
  return False
345
615
  return self.transfer_statuses[room].is_done()
346
616
 
347
- def _bind_server_socket(self):
348
- self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
349
-
350
617
  def _start_bootstrap_thread(self):
351
618
  self._bind_server_socket()
352
619
 
@@ -387,7 +654,7 @@ class NixlKVManager(CommonKVManager):
387
654
  threading.Thread(target=bootstrap_thread).start()
388
655
 
389
656
 
390
- class NixlKVSender(BaseKVSender):
657
+ class NixlKVSender(CommonKVSender):
391
658
 
392
659
  def __init__(
393
660
  self,
@@ -397,20 +664,10 @@ class NixlKVSender(BaseKVSender):
397
664
  dest_tp_ranks: List[int],
398
665
  pp_rank: int,
399
666
  ):
400
- self.kv_mgr = mgr
401
- self.bootstrap_room = bootstrap_room
402
- self.aux_index = None
403
- self.bootstrap_server_url = bootstrap_addr
667
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
404
668
  self.xfer_handles = []
405
669
  self.has_sent = False
406
670
  self.chunk_id = 0
407
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
408
- # inner state
409
- self.curr_idx = 0
410
-
411
- def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
412
- self.num_kv_indices = num_kv_indices
413
- self.aux_index = aux_index
414
671
 
415
672
  def send(
416
673
  self,
@@ -454,11 +711,17 @@ class NixlKVReceiver(CommonKVReceiver):
454
711
  mgr: NixlKVManager,
455
712
  bootstrap_addr: str,
456
713
  bootstrap_room: Optional[int] = None,
457
- data_parallel_rank: Optional[int] = None,
714
+ prefill_dp_rank: Optional[int] = None,
458
715
  ):
459
716
  self.started_transfer = False
460
717
  self.conclude_state = None
461
- super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
718
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
719
+
720
+ # Track this room with its bootstrap address for heartbeat monitoring
721
+ if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
722
+ self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
723
+ self.bootstrap_room
724
+ )
462
725
 
463
726
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
464
727
  for bootstrap_info in self.bootstrap_infos:
@@ -494,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
494
757
 
495
758
  self.kv_mgr.update_transfer_status()
496
759
  if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
497
- self.conclude_state = KVPoll.Success
760
+ # Check if the transfer failed
761
+ if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
762
+ self.conclude_state = KVPoll.Failed
763
+ logger.error(
764
+ f"Transfer for room {self.bootstrap_room} failed due to node failure"
765
+ )
766
+ else:
767
+ self.conclude_state = KVPoll.Success
498
768
  del self.kv_mgr.transfer_statuses[self.bootstrap_room]
499
- return KVPoll.Success # type: ignore
769
+ return self.conclude_state # type: ignore
500
770
  return KVPoll.WaitingForInput # type: ignore
501
771
 
502
772
  def _register_kv_args(self):
@@ -521,6 +791,9 @@ class NixlKVReceiver(CommonKVReceiver):
521
791
  packed_kv_data_ptrs,
522
792
  packed_aux_data_ptrs,
523
793
  str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
794
+ str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
795
+ str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
796
+ str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
524
797
  ]
525
798
  )
526
799