sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ import struct
9
9
  import threading
10
10
  import time
11
11
  from collections import defaultdict
12
- from typing import Dict, List, Optional, Tuple
12
+ from typing import Dict, List, Optional, Set, Tuple
13
13
 
14
14
  import numpy as np
15
15
  import numpy.typing as npt
@@ -58,6 +58,7 @@ class TransferKVChunk:
58
58
  index_slice: slice
59
59
  is_last: bool
60
60
  prefill_aux_index: Optional[int]
61
+ state_indices: Optional[List[int]]
61
62
 
62
63
 
63
64
  # decode
@@ -69,6 +70,7 @@ class TransferInfo:
69
70
  mooncake_session_id: str
70
71
  dst_kv_indices: npt.NDArray[np.int32]
71
72
  dst_aux_index: int
73
+ dst_state_indices: List[int]
72
74
  required_dst_info_num: int
73
75
  is_dummy: bool
74
76
 
@@ -78,9 +80,14 @@ class TransferInfo:
78
80
  is_dummy = True
79
81
  dst_kv_indices = np.array([], dtype=np.int32)
80
82
  dst_aux_index = None
83
+ dst_state_indices = []
81
84
  else:
82
85
  dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
83
86
  dst_aux_index = int(msg[5].decode("ascii"))
87
+ if msg[6] == b"":
88
+ dst_state_indices = []
89
+ else:
90
+ dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
84
91
  is_dummy = False
85
92
  return cls(
86
93
  room=int(msg[0].decode("ascii")),
@@ -89,7 +96,8 @@ class TransferInfo:
89
96
  mooncake_session_id=msg[3].decode("ascii"),
90
97
  dst_kv_indices=dst_kv_indices,
91
98
  dst_aux_index=dst_aux_index,
92
- required_dst_info_num=int(msg[6].decode("ascii")),
99
+ dst_state_indices=dst_state_indices,
100
+ required_dst_info_num=int(msg[7].decode("ascii")),
93
101
  is_dummy=is_dummy,
94
102
  )
95
103
 
@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
103
111
  mooncake_session_id: str
104
112
  dst_kv_ptrs: list[int]
105
113
  dst_aux_ptrs: list[int]
114
+ dst_state_data_ptrs: list[int]
106
115
  dst_tp_rank: int
107
116
  dst_attn_tp_size: int
108
117
  dst_kv_item_len: int
@@ -116,9 +125,10 @@ class KVArgsRegisterInfo:
116
125
  mooncake_session_id=msg[3].decode("ascii"),
117
126
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
118
127
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
119
- dst_tp_rank=int(msg[6].decode("ascii")),
120
- dst_attn_tp_size=int(msg[7].decode("ascii")),
121
- dst_kv_item_len=int(msg[8].decode("ascii")),
128
+ dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
129
+ dst_tp_rank=int(msg[7].decode("ascii")),
130
+ dst_attn_tp_size=int(msg[8].decode("ascii")),
131
+ dst_kv_item_len=int(msg[9].decode("ascii")),
122
132
  )
123
133
 
124
134
 
@@ -164,7 +174,7 @@ class MooncakeKVManager(CommonKVManager):
164
174
  cpu_count = os.cpu_count()
165
175
  transfer_thread_pool_size = get_int_env_var(
166
176
  "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
167
- min(max(4, int(0.75 * cpu_count) // 8), 12),
177
+ min(max(4, int(0.5 * cpu_count) // 8), 12),
168
178
  )
169
179
  transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
170
180
  self.transfer_queues: List[FastQueue] = [
@@ -239,6 +249,12 @@ class MooncakeKVManager(CommonKVManager):
239
249
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
240
250
  )
241
251
 
252
+ # Batch register state/extra pool data buffers
253
+ if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
254
+ self.engine.batch_register(
255
+ self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
256
+ )
257
+
242
258
  def _transfer_data(self, mooncake_session_id, transfer_blocks):
243
259
  if not transfer_blocks:
244
260
  return 0
@@ -248,17 +264,23 @@ class MooncakeKVManager(CommonKVManager):
248
264
  mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
249
265
  )
250
266
 
251
- def send_kvcache(
267
+ def _send_kvcache_generic(
252
268
  self,
253
269
  mooncake_session_id: str,
254
- prefill_kv_indices: npt.NDArray[np.int32],
255
- dst_kv_ptrs: list[int],
256
- dst_kv_indices: npt.NDArray[np.int32],
270
+ src_data_ptrs: list[int],
271
+ dst_data_ptrs: list[int],
272
+ item_lens: list[int],
273
+ prefill_data_indices: npt.NDArray[np.int32],
274
+ dst_data_indices: npt.NDArray[np.int32],
257
275
  executor: concurrent.futures.ThreadPoolExecutor,
258
- ):
259
- # Group by indices
276
+ ) -> int:
277
+ """
278
+ Generic KV cache transfer supporting both MHA and MLA architectures.
279
+ This method is used by both send_kvcache (full pool) and maybe_send_extra.
280
+ """
281
+ # Group by indices for optimization
260
282
  prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
261
- prefill_kv_indices, dst_kv_indices
283
+ prefill_data_indices, dst_data_indices
262
284
  )
263
285
 
264
286
  layers_params = None
@@ -266,9 +288,9 @@ class MooncakeKVManager(CommonKVManager):
266
288
  # pp is not supported on the decode side yet
267
289
  if self.is_mla_backend:
268
290
  src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
269
- self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
291
+ self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
270
292
  )
271
- kv_item_len = self.kv_args.kv_item_lens[0]
293
+ kv_item_len = item_lens[0]
272
294
  layers_params = [
273
295
  (
274
296
  src_kv_ptrs[layer_id],
@@ -279,9 +301,9 @@ class MooncakeKVManager(CommonKVManager):
279
301
  ]
280
302
  else:
281
303
  src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
282
- self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
304
+ self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
283
305
  )
284
- kv_item_len = self.kv_args.kv_item_lens[0]
306
+ kv_item_len = item_lens[0]
285
307
  layers_params = [
286
308
  (
287
309
  src_k_ptrs[layer_id],
@@ -345,6 +367,24 @@ class MooncakeKVManager(CommonKVManager):
345
367
 
346
368
  return 0
347
369
 
370
+ def send_kvcache(
371
+ self,
372
+ mooncake_session_id: str,
373
+ prefill_kv_indices: npt.NDArray[np.int32],
374
+ dst_kv_ptrs: list[int],
375
+ dst_kv_indices: npt.NDArray[np.int32],
376
+ executor: concurrent.futures.ThreadPoolExecutor,
377
+ ):
378
+ return self._send_kvcache_generic(
379
+ mooncake_session_id=mooncake_session_id,
380
+ src_data_ptrs=self.kv_args.kv_data_ptrs,
381
+ dst_data_ptrs=dst_kv_ptrs,
382
+ item_lens=self.kv_args.kv_item_lens,
383
+ prefill_data_indices=prefill_kv_indices,
384
+ dst_data_indices=dst_kv_indices,
385
+ executor=executor,
386
+ )
387
+
348
388
  def send_kvcache_slice(
349
389
  self,
350
390
  mooncake_session_id: str,
@@ -593,6 +633,59 @@ class MooncakeKVManager(CommonKVManager):
593
633
  f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
594
634
  )
595
635
 
636
+ def maybe_send_extra(
637
+ self,
638
+ req: TransferInfo,
639
+ prefill_state_indices: list[int],
640
+ dst_state_data_ptrs: list[int],
641
+ executor: concurrent.futures.ThreadPoolExecutor,
642
+ ):
643
+ """Send state or extra pool data with type-specific handling."""
644
+ state_type = getattr(self.kv_args, "state_type", "none")
645
+
646
+ if state_type == "mamba":
647
+ return self._send_mamba_state(
648
+ req,
649
+ prefill_state_indices,
650
+ dst_state_data_ptrs,
651
+ )
652
+ elif state_type in ["swa", "nsa"]:
653
+ # Reuse _send_kvcache_generic interface to send extra pool data
654
+ prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
655
+ dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
656
+ return self._send_kvcache_generic(
657
+ mooncake_session_id=req.mooncake_session_id,
658
+ src_data_ptrs=self.kv_args.state_data_ptrs,
659
+ dst_data_ptrs=dst_state_data_ptrs,
660
+ item_lens=self.kv_args.state_item_lens,
661
+ prefill_data_indices=prefill_state_indices,
662
+ dst_data_indices=dst_state_indices,
663
+ executor=executor,
664
+ )
665
+ else:
666
+ return 0
667
+
668
+ def _send_mamba_state(
669
+ self,
670
+ req: TransferInfo,
671
+ prefill_mamba_index: list[int],
672
+ dst_state_data_ptrs: list[int],
673
+ ):
674
+ """Transfer Mamba states."""
675
+ assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
676
+
677
+ transfer_blocks = []
678
+ prefill_state_data_ptrs = self.kv_args.state_data_ptrs
679
+ prefill_state_item_lens = self.kv_args.state_item_lens
680
+
681
+ for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
682
+ length = prefill_state_item_lens[i]
683
+ src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
684
+ dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
685
+ transfer_blocks.append((src_addr, dst_addr, length))
686
+
687
+ return self._transfer_data(req.mooncake_session_id, transfer_blocks)
688
+
596
689
  def sync_status_to_decode_endpoint(
597
690
  self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
598
691
  ):
@@ -702,6 +795,22 @@ class MooncakeKVManager(CommonKVManager):
702
795
  break
703
796
 
704
797
  if kv_chunk.is_last:
798
+ if kv_chunk.state_indices is not None:
799
+ if not self.is_mla_backend and (
800
+ self.attn_tp_size
801
+ != target_rank_registration_info.dst_attn_tp_size
802
+ ):
803
+ raise RuntimeError(
804
+ f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
805
+ )
806
+
807
+ self.maybe_send_extra(
808
+ req,
809
+ kv_chunk.state_indices,
810
+ target_rank_registration_info.dst_state_data_ptrs,
811
+ executor,
812
+ )
813
+
705
814
  if self.pp_group.is_last_rank:
706
815
  # Only the last chunk we need to send the aux data
707
816
  ret = self.send_aux(
@@ -765,7 +874,7 @@ class MooncakeKVManager(CommonKVManager):
765
874
  )
766
875
  continue
767
876
  else:
768
- required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
877
+ required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
769
878
  room = int(room)
770
879
  if room not in self.transfer_infos:
771
880
  self.transfer_infos[room] = {}
@@ -876,6 +985,7 @@ class MooncakeKVManager(CommonKVManager):
876
985
  index_slice: slice,
877
986
  is_last: bool,
878
987
  aux_index: Optional[int] = None,
988
+ state_indices: Optional[List[int]] = None,
879
989
  ):
880
990
  assert self.disaggregation_mode == DisaggregationMode.PREFILL
881
991
  assert not is_last or (is_last and aux_index is not None)
@@ -909,6 +1019,7 @@ class MooncakeKVManager(CommonKVManager):
909
1019
  index_slice=index_slice,
910
1020
  is_last=is_last,
911
1021
  prefill_aux_index=aux_index,
1022
+ state_indices=state_indices,
912
1023
  )
913
1024
  )
914
1025
 
@@ -989,6 +1100,7 @@ class MooncakeKVSender(CommonKVSender):
989
1100
  def send(
990
1101
  self,
991
1102
  kv_indices: npt.NDArray[np.int32],
1103
+ state_indices: Optional[List[int]] = None,
992
1104
  ):
993
1105
  index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
994
1106
  self.curr_idx += len(kv_indices)
@@ -1008,6 +1120,7 @@ class MooncakeKVSender(CommonKVSender):
1008
1120
  index_slice,
1009
1121
  True,
1010
1122
  aux_index=self.aux_index,
1123
+ state_indices=state_indices,
1011
1124
  )
1012
1125
 
1013
1126
  def poll(self) -> KVPoll:
@@ -1110,6 +1223,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
1110
1223
  packed_aux_data_ptrs = b"".join(
1111
1224
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
1112
1225
  )
1226
+ packed_state_data_ptrs = b"".join(
1227
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
1228
+ )
1113
1229
  # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
1114
1230
  tp_rank = self.kv_mgr.kv_args.engine_rank
1115
1231
  kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
@@ -1127,13 +1243,27 @@ class MooncakeKVReceiver(CommonKVReceiver):
1127
1243
  self.session_id.encode("ascii"),
1128
1244
  packed_kv_data_ptrs,
1129
1245
  packed_aux_data_ptrs,
1246
+ packed_state_data_ptrs,
1130
1247
  dst_tp_rank,
1131
1248
  dst_attn_tp_size,
1132
1249
  dst_kv_item_len,
1133
1250
  ]
1134
1251
  )
1135
1252
 
1136
- def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
1253
+ def init(
1254
+ self,
1255
+ kv_indices: npt.NDArray[np.int32],
1256
+ aux_index: Optional[int] = None,
1257
+ state_indices: Optional[List[int]] = None,
1258
+ ):
1259
+ if self.bootstrap_infos is None:
1260
+ self.kv_mgr.record_failure(
1261
+ self.bootstrap_room,
1262
+ f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
1263
+ )
1264
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1265
+ return
1266
+
1137
1267
  for bootstrap_info in self.bootstrap_infos:
1138
1268
  sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
1139
1269
  is_dummy = bootstrap_info["is_dummy"]
@@ -1147,6 +1277,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
1147
1277
  self.session_id.encode("ascii"),
1148
1278
  kv_indices.tobytes() if not is_dummy else b"",
1149
1279
  str(aux_index).encode("ascii") if not is_dummy else b"",
1280
+ (
1281
+ np.array(
1282
+ state_indices,
1283
+ dtype=np.int32,
1284
+ ).tobytes()
1285
+ if not is_dummy and state_indices is not None
1286
+ else b""
1287
+ ),
1150
1288
  str(self.required_dst_info_num).encode("ascii"),
1151
1289
  ]
1152
1290
  )
@@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager):
319
319
 
320
320
  logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
321
321
  # Make descs
322
- num_layers = len(self.kv_args.kv_data_ptrs)
322
+ if self.is_mla_backend:
323
+ src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
324
+ self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
325
+ )
326
+ kv_item_len = self.kv_args.kv_item_lens[0]
327
+ layers_params = [
328
+ (
329
+ src_kv_ptrs[layer_id],
330
+ dst_kv_ptrs[layer_id],
331
+ kv_item_len,
332
+ )
333
+ for layer_id in range(layers_current_pp_stage)
334
+ ]
335
+ else:
336
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
337
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
338
+ )
339
+
340
+ kv_item_len = self.kv_args.kv_item_lens[0]
341
+ layers_params = [
342
+ (
343
+ src_k_ptrs[layer_id],
344
+ dst_k_ptrs[layer_id],
345
+ kv_item_len,
346
+ )
347
+ for layer_id in range(layers_current_pp_stage)
348
+ ] + [
349
+ (
350
+ src_v_ptrs[layer_id],
351
+ dst_v_ptrs[layer_id],
352
+ kv_item_len,
353
+ )
354
+ for layer_id in range(layers_current_pp_stage)
355
+ ]
356
+
323
357
  src_addrs = []
324
358
  dst_addrs = []
325
- for layer_id in range(num_layers):
326
- src_ptr = self.kv_args.kv_data_ptrs[layer_id]
327
- dst_ptr = dst_kv_ptrs[layer_id]
328
- item_len = self.kv_args.kv_item_lens[layer_id]
329
-
359
+ for src_ptr, dst_ptr, item_len in layers_params:
330
360
  for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
331
361
  src_addr = src_ptr + int(prefill_index[0]) * item_len
332
362
  dst_addr = dst_ptr + int(decode_index[0]) * item_len
@@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager):
397
427
  num_heads_to_send = dst_heads_per_rank
398
428
  dst_head_start_offset = 0
399
429
 
430
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
431
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
432
+ )
400
433
  # Create transfer descriptors
401
434
  src_addrs = []
402
435
  dst_addrs = []
@@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager):
404
437
  bytes_per_token_on_prefill = src_kv_item_len // page_size
405
438
  bytes_per_token_on_decode = dst_kv_item_len // page_size
406
439
 
407
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
408
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
409
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
410
- dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
411
- dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
412
-
413
440
  # Calculate precise byte offset and length for the sub-slice within the token
414
441
  src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
415
442
  dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
@@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager):
420
447
  src_k_ptrs[layer_id],
421
448
  dst_k_ptrs[layer_id],
422
449
  )
423
- for layer_id in range(len(src_k_ptrs))
450
+ for layer_id in range(layers_current_pp_stage)
424
451
  ] + [
425
452
  (
426
453
  src_v_ptrs[layer_id],
427
454
  dst_v_ptrs[layer_id],
428
455
  )
429
- for layer_id in range(len(src_v_ptrs))
456
+ for layer_id in range(layers_current_pp_stage)
430
457
  ]
431
458
 
432
459
  src_addrs = []
@@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager):
496
523
  dst_aux_index: int,
497
524
  notif: str,
498
525
  ):
499
- # Make descs
500
- aux_item_len = self.kv_args.aux_item_lens[0]
501
- prefill_aux_addr = (
502
- self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
503
- )
504
- decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
505
- src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
506
- dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
526
+ src_addrs = []
527
+ dst_addrs = []
528
+
529
+ prefill_aux_ptrs = self.kv_args.aux_data_ptrs
530
+ prefill_aux_item_lens = self.kv_args.aux_item_lens
531
+
532
+ for i, _ in enumerate(dst_aux_ptrs):
533
+ length = prefill_aux_item_lens[i]
534
+ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
535
+ dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
536
+ src_addrs.append((src_addr, length, 0))
537
+ dst_addrs.append((dst_addr, length, 0))
538
+
507
539
  src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
508
540
  dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
509
541
  # Transfer data
@@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager):
576
608
 
577
609
  handles.append(kv_xfer_handle)
578
610
  # Only the last chunk we need to send the aux data.
579
- if is_last:
611
+ if is_last and self.pp_group.is_last_rank:
580
612
  assert aux_index is not None
581
613
  aux_xfer_handle = self.send_aux(
582
614
  req.agent_name,
@@ -672,6 +704,7 @@ class NixlKVSender(CommonKVSender):
672
704
  def send(
673
705
  self,
674
706
  kv_indices: npt.NDArray[np.int32],
707
+ state_indices: Optional[List[int]] = None,
675
708
  ):
676
709
  index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
677
710
  self.curr_idx += len(kv_indices)
@@ -723,7 +756,19 @@ class NixlKVReceiver(CommonKVReceiver):
723
756
  self.bootstrap_room
724
757
  )
725
758
 
726
- def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
759
+ def init(
760
+ self,
761
+ kv_indices: npt.NDArray[np.int32],
762
+ aux_index: Optional[int] = None,
763
+ state_indices: Optional[List[int]] = None,
764
+ ):
765
+ if self.bootstrap_infos is None:
766
+ logger.error(
767
+ f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
768
+ )
769
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
770
+ return
771
+
727
772
  for bootstrap_info in self.bootstrap_infos:
728
773
  logger.debug(
729
774
  f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"