sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,37 +1,30 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import dataclasses
5
4
  import logging
6
- import queue
7
- import socket
5
+ import os
8
6
  import struct
9
7
  import threading
8
+ import time
10
9
  import uuid
11
10
  from collections import defaultdict
12
- from functools import cache
13
- from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
11
+ from typing import Dict, List, Optional, Set
14
12
 
15
13
  import numpy as np
16
14
  import numpy.typing as npt
17
15
  import requests
18
- import zmq
19
- from aiohttp import web
20
16
 
21
- from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
17
+ from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
22
18
  from sglang.srt.disaggregation.common.conn import (
23
19
  CommonKVBootstrapServer,
24
20
  CommonKVManager,
25
21
  CommonKVReceiver,
22
+ CommonKVSender,
26
23
  )
27
24
  from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
25
  from sglang.srt.disaggregation.utils import DisaggregationMode
29
26
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.utils import (
31
- format_tcp_address,
32
- get_local_ip_auto,
33
- is_valid_ipv6_address,
34
- )
27
+ from sglang.srt.utils import get_int_env_var
35
28
 
36
29
  logger = logging.getLogger(__name__)
37
30
 
@@ -113,8 +106,14 @@ class TransferStatus:
113
106
  def is_done(self):
114
107
  if self.num_kvs_expected is None:
115
108
  return False
109
+ # Check for failure state
110
+ if self.num_kvs_expected == -1:
111
+ return True # Failed transfers are considered "done"
116
112
  return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
117
113
 
114
+ def is_failed(self):
115
+ return self.num_kvs_expected == -1
116
+
118
117
 
119
118
  class NixlKVManager(CommonKVManager):
120
119
  def __init__(
@@ -134,26 +133,133 @@ class NixlKVManager(CommonKVManager):
134
133
  "to run SGLang with NixlTransferEngine."
135
134
  ) from e
136
135
  self.agent = nixl_agent(str(uuid.uuid4()))
137
- self.local_ip = get_local_ip_auto()
138
- self.server_socket = zmq.Context().socket(zmq.PULL)
139
- if is_valid_ipv6_address(self.local_ip):
140
- self.server_socket.setsockopt(zmq.IPV6, 1)
141
136
  self.register_buffer_to_engine()
142
137
 
143
138
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
144
- self.request_status: Dict[int, KVPoll] = {}
145
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
146
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
147
139
  self._start_bootstrap_thread()
148
140
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
149
141
  self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
150
142
  TransferStatus
151
143
  )
144
+ self.heartbeat_failures = {}
145
+ self.session_pool = defaultdict(requests.Session)
146
+ self.session_pool_lock = threading.Lock()
147
+ self.addr_to_rooms_tracker = defaultdict(set)
148
+ self.connection_lock = threading.Lock()
149
+
150
+ # Heartbeat interval should be at least 2 seconds
151
+ self.heartbeat_interval = max(
152
+ float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
153
+ )
154
+ # Heartbeat failure should be at least 1
155
+ self.max_failures = max(
156
+ get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
157
+ )
158
+ self._start_heartbeat_checker_thread()
152
159
  else:
153
160
  raise ValueError(
154
161
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
155
162
  )
156
163
 
164
+ def _start_heartbeat_checker_thread(self):
165
+ """
166
+ Start the heartbeat checker thread for Decode worker.
167
+ TODO (smor): unite nixl heartbeat checker with mooncake's.
168
+ """
169
+
170
+ def heartbeat_checker():
171
+ while True:
172
+ time.sleep(self.heartbeat_interval)
173
+ with self.connection_lock:
174
+ addresses = list(self.prefill_dp_size_table.keys())
175
+
176
+ for bootstrap_addr in addresses:
177
+ session = None
178
+ try:
179
+ with self.session_pool_lock:
180
+ session = self.session_pool[bootstrap_addr]
181
+ response = session.get(
182
+ f"http://{bootstrap_addr}/health",
183
+ timeout=(2, 3),
184
+ headers={"Connection": "keep-alive"},
185
+ )
186
+ if response.status_code == 200:
187
+ self.heartbeat_failures[bootstrap_addr] = 0
188
+
189
+ current_rooms = self.addr_to_rooms_tracker[
190
+ bootstrap_addr
191
+ ].copy()
192
+
193
+ for bootstrap_room in current_rooms:
194
+ # Remove successful transfers from the tracker
195
+ if bootstrap_room not in self.transfer_statuses:
196
+ self.addr_to_rooms_tracker[bootstrap_addr].discard(
197
+ bootstrap_room
198
+ )
199
+ else:
200
+ logger.info(
201
+ f"Attempting to reconnect to {bootstrap_addr}..."
202
+ )
203
+ self.heartbeat_failures[bootstrap_addr] = (
204
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
205
+ )
206
+ with self.session_pool_lock:
207
+ if bootstrap_addr in self.session_pool:
208
+ del self.session_pool[bootstrap_addr]
209
+ except Exception:
210
+ logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
211
+ self.heartbeat_failures[bootstrap_addr] = (
212
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
213
+ )
214
+
215
+ if (
216
+ self.heartbeat_failures.get(bootstrap_addr, 0)
217
+ >= self.max_failures
218
+ ):
219
+ self._handle_node_failure(bootstrap_addr)
220
+ with self.session_pool_lock:
221
+ if bootstrap_addr in self.session_pool:
222
+ del self.session_pool[bootstrap_addr]
223
+
224
+ threading.Thread(target=heartbeat_checker, daemon=True).start()
225
+
226
+ def _handle_node_failure(self, failed_bootstrap_addr):
227
+ """Handle failure of a prefill node."""
228
+ with self.connection_lock:
229
+ keys_to_remove = [
230
+ k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
231
+ ]
232
+ for k in keys_to_remove:
233
+ del self.connection_pool[k]
234
+ if failed_bootstrap_addr in self.prefill_tp_size_table:
235
+ del self.prefill_tp_size_table[failed_bootstrap_addr]
236
+ if failed_bootstrap_addr in self.prefill_dp_size_table:
237
+ del self.prefill_dp_size_table[failed_bootstrap_addr]
238
+ if failed_bootstrap_addr in self.prefill_pp_size_table:
239
+ del self.prefill_pp_size_table[failed_bootstrap_addr]
240
+
241
+ possible_affected_rooms = self.addr_to_rooms_tracker.get(
242
+ failed_bootstrap_addr, []
243
+ )
244
+ if failed_bootstrap_addr in self.addr_to_rooms_tracker:
245
+ del self.addr_to_rooms_tracker[failed_bootstrap_addr]
246
+
247
+ # Mark all pending transfers associated with the failed node as failed
248
+ affected_rooms = []
249
+ for room in possible_affected_rooms:
250
+ if (
251
+ room in self.transfer_statuses
252
+ and not self.transfer_statuses[room].is_done()
253
+ ):
254
+ # Mark the transfer as failed by setting a special state
255
+ self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
256
+ affected_rooms.append(room)
257
+
258
+ logger.error(
259
+ f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
260
+ f"{len(affected_rooms)} transfers affected"
261
+ )
262
+
157
263
  def check_status(self, bootstrap_room: int):
158
264
  return self.request_status[bootstrap_room]
159
265
 
@@ -166,6 +272,9 @@ class NixlKVManager(CommonKVManager):
166
272
  self.request_status[bootstrap_room], status
167
273
  )
168
274
 
275
+ def record_failure(self, bootstrap_room: int, failure_reason: str):
276
+ pass
277
+
169
278
  def register_buffer_to_engine(self):
170
279
  kv_addrs = []
171
280
  for kv_data_ptr, kv_data_len in zip(
@@ -210,14 +319,44 @@ class NixlKVManager(CommonKVManager):
210
319
 
211
320
  logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
212
321
  # Make descs
213
- num_layers = len(self.kv_args.kv_data_ptrs)
322
+ if self.is_mla_backend:
323
+ src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
324
+ self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
325
+ )
326
+ kv_item_len = self.kv_args.kv_item_lens[0]
327
+ layers_params = [
328
+ (
329
+ src_kv_ptrs[layer_id],
330
+ dst_kv_ptrs[layer_id],
331
+ kv_item_len,
332
+ )
333
+ for layer_id in range(layers_current_pp_stage)
334
+ ]
335
+ else:
336
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
337
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
338
+ )
339
+
340
+ kv_item_len = self.kv_args.kv_item_lens[0]
341
+ layers_params = [
342
+ (
343
+ src_k_ptrs[layer_id],
344
+ dst_k_ptrs[layer_id],
345
+ kv_item_len,
346
+ )
347
+ for layer_id in range(layers_current_pp_stage)
348
+ ] + [
349
+ (
350
+ src_v_ptrs[layer_id],
351
+ dst_v_ptrs[layer_id],
352
+ kv_item_len,
353
+ )
354
+ for layer_id in range(layers_current_pp_stage)
355
+ ]
356
+
214
357
  src_addrs = []
215
358
  dst_addrs = []
216
- for layer_id in range(num_layers):
217
- src_ptr = self.kv_args.kv_data_ptrs[layer_id]
218
- dst_ptr = dst_kv_ptrs[layer_id]
219
- item_len = self.kv_args.kv_item_lens[layer_id]
220
-
359
+ for src_ptr, dst_ptr, item_len in layers_params:
221
360
  for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
222
361
  src_addr = src_ptr + int(prefill_index[0]) * item_len
223
362
  dst_addr = dst_ptr + int(decode_index[0]) * item_len
@@ -288,6 +427,9 @@ class NixlKVManager(CommonKVManager):
288
427
  num_heads_to_send = dst_heads_per_rank
289
428
  dst_head_start_offset = 0
290
429
 
430
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
431
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
432
+ )
291
433
  # Create transfer descriptors
292
434
  src_addrs = []
293
435
  dst_addrs = []
@@ -295,12 +437,6 @@ class NixlKVManager(CommonKVManager):
295
437
  bytes_per_token_on_prefill = src_kv_item_len // page_size
296
438
  bytes_per_token_on_decode = dst_kv_item_len // page_size
297
439
 
298
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
299
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
300
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
301
- dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
302
- dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
303
-
304
440
  # Calculate precise byte offset and length for the sub-slice within the token
305
441
  src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
306
442
  dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
@@ -311,13 +447,13 @@ class NixlKVManager(CommonKVManager):
311
447
  src_k_ptrs[layer_id],
312
448
  dst_k_ptrs[layer_id],
313
449
  )
314
- for layer_id in range(len(src_k_ptrs))
450
+ for layer_id in range(layers_current_pp_stage)
315
451
  ] + [
316
452
  (
317
453
  src_v_ptrs[layer_id],
318
454
  dst_v_ptrs[layer_id],
319
455
  )
320
- for layer_id in range(len(src_v_ptrs))
456
+ for layer_id in range(layers_current_pp_stage)
321
457
  ]
322
458
 
323
459
  src_addrs = []
@@ -387,14 +523,19 @@ class NixlKVManager(CommonKVManager):
387
523
  dst_aux_index: int,
388
524
  notif: str,
389
525
  ):
390
- # Make descs
391
- aux_item_len = self.kv_args.aux_item_lens[0]
392
- prefill_aux_addr = (
393
- self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
394
- )
395
- decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
396
- src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
397
- dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
526
+ src_addrs = []
527
+ dst_addrs = []
528
+
529
+ prefill_aux_ptrs = self.kv_args.aux_data_ptrs
530
+ prefill_aux_item_lens = self.kv_args.aux_item_lens
531
+
532
+ for i, _ in enumerate(dst_aux_ptrs):
533
+ length = prefill_aux_item_lens[i]
534
+ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
535
+ dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
536
+ src_addrs.append((src_addr, length, 0))
537
+ dst_addrs.append((dst_addr, length, 0))
538
+
398
539
  src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
399
540
  dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
400
541
  # Transfer data
@@ -438,7 +579,7 @@ class NixlKVManager(CommonKVManager):
438
579
  notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
439
580
  decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
440
581
 
441
- if decode_tp_size == self.tp_size:
582
+ if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
442
583
  kv_xfer_handle = self.send_kvcache(
443
584
  req.agent_name,
444
585
  kv_indices,
@@ -455,7 +596,7 @@ class NixlKVManager(CommonKVManager):
455
596
  chunked_dst_kv_indice,
456
597
  self.decode_kv_args_table[req.agent_name].gpu_id,
457
598
  notif,
458
- prefill_tp_size=self.tp_size,
599
+ prefill_tp_size=self.attn_tp_size,
459
600
  decode_tp_size=decode_tp_size,
460
601
  decode_tp_rank=self.decode_kv_args_table[
461
602
  req.agent_name
@@ -467,7 +608,7 @@ class NixlKVManager(CommonKVManager):
467
608
 
468
609
  handles.append(kv_xfer_handle)
469
610
  # Only the last chunk we need to send the aux data.
470
- if is_last:
611
+ if is_last and self.pp_group.is_last_rank:
471
612
  assert aux_index is not None
472
613
  aux_xfer_handle = self.send_aux(
473
614
  req.agent_name,
@@ -505,9 +646,6 @@ class NixlKVManager(CommonKVManager):
505
646
  return False
506
647
  return self.transfer_statuses[room].is_done()
507
648
 
508
- def _bind_server_socket(self):
509
- self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
510
-
511
649
  def _start_bootstrap_thread(self):
512
650
  self._bind_server_socket()
513
651
 
@@ -548,7 +686,7 @@ class NixlKVManager(CommonKVManager):
548
686
  threading.Thread(target=bootstrap_thread).start()
549
687
 
550
688
 
551
- class NixlKVSender(BaseKVSender):
689
+ class NixlKVSender(CommonKVSender):
552
690
 
553
691
  def __init__(
554
692
  self,
@@ -558,24 +696,15 @@ class NixlKVSender(BaseKVSender):
558
696
  dest_tp_ranks: List[int],
559
697
  pp_rank: int,
560
698
  ):
561
- self.kv_mgr = mgr
562
- self.bootstrap_room = bootstrap_room
563
- self.aux_index = None
564
- self.bootstrap_server_url = bootstrap_addr
699
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
565
700
  self.xfer_handles = []
566
701
  self.has_sent = False
567
702
  self.chunk_id = 0
568
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
569
- # inner state
570
- self.curr_idx = 0
571
-
572
- def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
573
- self.num_kv_indices = num_kv_indices
574
- self.aux_index = aux_index
575
703
 
576
704
  def send(
577
705
  self,
578
706
  kv_indices: npt.NDArray[np.int32],
707
+ state_indices: Optional[List[int]] = None,
579
708
  ):
580
709
  index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
581
710
  self.curr_idx += len(kv_indices)
@@ -621,7 +750,25 @@ class NixlKVReceiver(CommonKVReceiver):
621
750
  self.conclude_state = None
622
751
  super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
623
752
 
624
- def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
753
+ # Track this room with its bootstrap address for heartbeat monitoring
754
+ if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
755
+ self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
756
+ self.bootstrap_room
757
+ )
758
+
759
+ def init(
760
+ self,
761
+ kv_indices: npt.NDArray[np.int32],
762
+ aux_index: Optional[int] = None,
763
+ state_indices: Optional[List[int]] = None,
764
+ ):
765
+ if self.bootstrap_infos is None:
766
+ logger.error(
767
+ f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
768
+ )
769
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
770
+ return
771
+
625
772
  for bootstrap_info in self.bootstrap_infos:
626
773
  logger.debug(
627
774
  f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
@@ -655,9 +802,16 @@ class NixlKVReceiver(CommonKVReceiver):
655
802
 
656
803
  self.kv_mgr.update_transfer_status()
657
804
  if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
658
- self.conclude_state = KVPoll.Success
805
+ # Check if the transfer failed
806
+ if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
807
+ self.conclude_state = KVPoll.Failed
808
+ logger.error(
809
+ f"Transfer for room {self.bootstrap_room} failed due to node failure"
810
+ )
811
+ else:
812
+ self.conclude_state = KVPoll.Success
659
813
  del self.kv_mgr.transfer_statuses[self.bootstrap_room]
660
- return KVPoll.Success # type: ignore
814
+ return self.conclude_state # type: ignore
661
815
  return KVPoll.WaitingForInput # type: ignore
662
816
 
663
817
  def _register_kv_args(self):