sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -3,22 +3,26 @@ import logging
3
3
  import threading
4
4
  from enum import IntEnum
5
5
  from functools import wraps
6
+ from typing import Optional
6
7
 
7
8
  import psutil
8
9
  import torch
9
10
 
10
11
  from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
11
- from sglang.srt.utils import is_npu
12
+ from sglang.srt.utils import is_npu, is_xpu
12
13
 
13
14
  _is_npu = is_npu()
14
- if not _is_npu:
15
+ _is_xpu = is_xpu()
16
+ if not (_is_npu or _is_xpu):
15
17
  from sgl_kernel.kvcacheio import (
16
18
  transfer_kv_all_layer,
19
+ transfer_kv_all_layer_direct_lf_pf,
17
20
  transfer_kv_all_layer_lf_pf,
18
21
  transfer_kv_all_layer_mla,
19
22
  transfer_kv_all_layer_mla_lf_pf,
20
23
  transfer_kv_direct,
21
24
  transfer_kv_per_layer,
25
+ transfer_kv_per_layer_direct_pf_lf,
22
26
  transfer_kv_per_layer_mla,
23
27
  transfer_kv_per_layer_mla_pf_lf,
24
28
  transfer_kv_per_layer_pf_lf,
@@ -27,27 +31,13 @@ if not _is_npu:
27
31
  logger = logging.getLogger(__name__)
28
32
 
29
33
 
30
- class MemoryStateInt(IntEnum):
31
- IDLE = 0
32
- RESERVED = 1
33
- PROTECTED = 2
34
- SYNCED = 3
35
- BACKUP = 4
34
+ def synchronized(func):
35
+ @wraps(func)
36
+ def wrapper(self, *args, **kwargs):
37
+ with self.lock:
38
+ return func(self, *args, **kwargs)
36
39
 
37
-
38
- def synchronized(debug_only=False):
39
- def _decorator(func):
40
- @wraps(func)
41
- def wrapper(self, *args, **kwargs):
42
- if (not debug_only) or self.debug:
43
- with self.lock:
44
- return func(self, *args, **kwargs)
45
- else:
46
- return True
47
-
48
- return wrapper
49
-
50
- return _decorator
40
+ return wrapper
51
41
 
52
42
 
53
43
  class HostKVCache(abc.ABC):
@@ -76,6 +66,7 @@ class HostKVCache(abc.ABC):
76
66
  self.size = int(device_pool.size * host_to_device_ratio)
77
67
  # Align the host memory pool size to the page size
78
68
  self.size = self.size - (self.size % self.page_size)
69
+ self.page_num = self.size // self.page_size
79
70
  self.start_layer = device_pool.start_layer
80
71
  self.end_layer = device_pool.end_layer
81
72
 
@@ -105,7 +96,6 @@ class HostKVCache(abc.ABC):
105
96
 
106
97
  # A lock for synchronized operations on memory allocation and state transitions.
107
98
  self.lock = threading.RLock()
108
- self.debug = logger.isEnabledFor(logging.DEBUG)
109
99
  self.clear()
110
100
 
111
101
  @abc.abstractmethod
@@ -135,7 +125,7 @@ class HostKVCache(abc.ABC):
135
125
  raise NotImplementedError()
136
126
 
137
127
  @abc.abstractmethod
138
- def get_flat_data_page(self, index) -> torch.Tensor:
128
+ def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
139
129
  """
140
130
  Get a flat data page from the host memory pool.
141
131
  """
@@ -156,7 +146,7 @@ class HostKVCache(abc.ABC):
156
146
  """
157
147
  raise NotImplementedError()
158
148
 
159
- @synchronized()
149
+ @synchronized
160
150
  def clear(self):
161
151
  # Initialize memory states and tracking structures.
162
152
  self.mem_state = torch.zeros(
@@ -167,8 +157,8 @@ class HostKVCache(abc.ABC):
167
157
  def available_size(self):
168
158
  return len(self.free_slots)
169
159
 
170
- @synchronized()
171
- def alloc(self, need_size: int) -> torch.Tensor:
160
+ @synchronized
161
+ def alloc(self, need_size: int) -> Optional[torch.Tensor]:
172
162
  assert (
173
163
  need_size % self.page_size == 0
174
164
  ), "The requested size should be a multiple of the page size."
@@ -178,92 +168,13 @@ class HostKVCache(abc.ABC):
178
168
  select_index = self.free_slots[:need_size]
179
169
  self.free_slots = self.free_slots[need_size:]
180
170
 
181
- if self.debug:
182
- self.mem_state[select_index] = MemoryStateInt.RESERVED
183
-
184
171
  return select_index
185
172
 
186
- @synchronized()
173
+ @synchronized
187
174
  def free(self, indices: torch.Tensor) -> int:
188
175
  self.free_slots = torch.cat([self.free_slots, indices])
189
- if self.debug:
190
- self.mem_state[indices] = MemoryStateInt.IDLE
191
176
  return len(indices)
192
177
 
193
- @synchronized(debug_only=True)
194
- def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
195
- assert len(indices) > 0, "The indices should not be empty"
196
- states = self.mem_state[indices]
197
- assert (
198
- states == states[0]
199
- ).all(), "The memory slots should have the same state {}".format(states)
200
- return MemoryStateInt(states[0].item())
201
-
202
- @synchronized(debug_only=True)
203
- def is_reserved(self, indices: torch.Tensor) -> bool:
204
- return self.get_state(indices) == MemoryStateInt.RESERVED
205
-
206
- @synchronized(debug_only=True)
207
- def is_protected(self, indices: torch.Tensor) -> bool:
208
- return self.get_state(indices) == MemoryStateInt.PROTECTED
209
-
210
- @synchronized(debug_only=True)
211
- def is_synced(self, indices: torch.Tensor) -> bool:
212
- return self.get_state(indices) == MemoryStateInt.SYNCED
213
-
214
- @synchronized(debug_only=True)
215
- def is_backup(self, indices: torch.Tensor) -> bool:
216
- return self.get_state(indices) == MemoryStateInt.BACKUP
217
-
218
- @synchronized(debug_only=True)
219
- def update_backup(self, indices: torch.Tensor):
220
- if not self.is_synced(indices):
221
- raise ValueError(
222
- f"The host memory slots should be in SYNCED state before turning into BACKUP. "
223
- f"Current state: {self.get_state(indices)}"
224
- )
225
- self.mem_state[indices] = MemoryStateInt.BACKUP
226
-
227
- @synchronized(debug_only=True)
228
- def update_prefetch(self, indices: torch.Tensor):
229
- if not self.is_reserved(indices):
230
- raise ValueError(
231
- f"The host memory slots should be in RESERVED state before turning into BACKUP. "
232
- f"Current state: {self.get_state(indices)}"
233
- )
234
- self.mem_state[indices] = MemoryStateInt.BACKUP
235
-
236
- @synchronized(debug_only=True)
237
- def update_synced(self, indices: torch.Tensor):
238
- self.mem_state[indices] = MemoryStateInt.SYNCED
239
-
240
- @synchronized(debug_only=True)
241
- def protect_write(self, indices: torch.Tensor):
242
- if not self.is_reserved(indices):
243
- raise ValueError(
244
- f"The host memory slots should be RESERVED before write operations. "
245
- f"Current state: {self.get_state(indices)}"
246
- )
247
- self.mem_state[indices] = MemoryStateInt.PROTECTED
248
-
249
- @synchronized(debug_only=True)
250
- def protect_load(self, indices: torch.Tensor):
251
- if not self.is_backup(indices):
252
- raise ValueError(
253
- f"The host memory slots should be in BACKUP state before load operations. "
254
- f"Current state: {self.get_state(indices)}"
255
- )
256
- self.mem_state[indices] = MemoryStateInt.PROTECTED
257
-
258
- @synchronized(debug_only=True)
259
- def complete_io(self, indices: torch.Tensor):
260
- if not self.is_protected(indices):
261
- raise ValueError(
262
- f"The host memory slots should be PROTECTED during I/O operations. "
263
- f"Current state: {self.get_state(indices)}"
264
- )
265
- self.mem_state[indices] = MemoryStateInt.SYNCED
266
-
267
178
 
268
179
  class MHATokenToKVPoolHost(HostKVCache):
269
180
  device_pool: MHATokenToKVPool
@@ -315,6 +226,15 @@ class MHATokenToKVPoolHost(HostKVCache):
315
226
  dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
316
227
  elif self.layout == "page_first":
317
228
  dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
229
+ elif self.layout == "page_first_direct":
230
+ dims = (
231
+ 2,
232
+ self.page_num,
233
+ self.layer_num,
234
+ self.page_size,
235
+ self.head_num,
236
+ self.head_dim,
237
+ )
318
238
  else:
319
239
  raise ValueError(f"Unsupported layout: {self.layout}")
320
240
  self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
@@ -368,19 +288,31 @@ class MHATokenToKVPoolHost(HostKVCache):
368
288
  else:
369
289
  raise ValueError(f"Unsupported layout: {self.layout}")
370
290
  elif io_backend == "direct":
371
- assert (
372
- self.layout == "layer_first"
373
- ), f"Direct IO backend only supports layer_first layout."
374
- transfer_kv_direct(
375
- src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
376
- dst_layers=[
377
- device_pool.k_buffer[layer_id],
378
- device_pool.v_buffer[layer_id],
379
- ],
380
- src_indices=host_indices,
381
- dst_indices=device_indices,
382
- page_size=self.page_size,
383
- )
291
+ if self.layout == "layer_first":
292
+ transfer_kv_direct(
293
+ src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
294
+ dst_layers=[
295
+ device_pool.k_buffer[layer_id],
296
+ device_pool.v_buffer[layer_id],
297
+ ],
298
+ src_indices=host_indices,
299
+ dst_indices=device_indices,
300
+ page_size=self.page_size,
301
+ )
302
+ elif self.layout == "page_first_direct":
303
+ transfer_kv_per_layer_direct_pf_lf(
304
+ src_ptrs=[self.k_buffer, self.v_buffer],
305
+ dst_ptrs=[
306
+ device_pool.k_buffer[layer_id],
307
+ device_pool.v_buffer[layer_id],
308
+ ],
309
+ src_indices=host_indices,
310
+ dst_indices=device_indices,
311
+ layer_id=layer_id,
312
+ page_size=self.page_size,
313
+ )
314
+ else:
315
+ raise ValueError(f"Unsupported layout: {self.layout}")
384
316
  else:
385
317
  raise ValueError(f"Unsupported IO backend: {io_backend}")
386
318
 
@@ -414,26 +346,40 @@ class MHATokenToKVPoolHost(HostKVCache):
414
346
  else:
415
347
  raise ValueError(f"Unsupported layout: {self.layout}")
416
348
  elif io_backend == "direct":
417
- assert (
418
- self.layout == "layer_first"
419
- ), f"Direct IO backend only supports layer_first layout."
420
- transfer_kv_direct(
421
- src_layers=device_pool.k_buffer + device_pool.v_buffer,
422
- dst_layers=self.k_data_refs + self.v_data_refs,
423
- src_indices=device_indices,
424
- dst_indices=host_indices,
425
- page_size=self.page_size,
426
- )
349
+ if self.layout == "layer_first":
350
+ transfer_kv_direct(
351
+ src_layers=device_pool.k_buffer + device_pool.v_buffer,
352
+ dst_layers=self.k_data_refs + self.v_data_refs,
353
+ src_indices=device_indices,
354
+ dst_indices=host_indices,
355
+ page_size=self.page_size,
356
+ )
357
+ elif self.layout == "page_first_direct":
358
+ transfer_kv_all_layer_direct_lf_pf(
359
+ src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
360
+ dst_ptrs=[self.k_buffer, self.v_buffer],
361
+ src_indices=device_indices,
362
+ dst_indices=host_indices,
363
+ page_size=self.page_size,
364
+ )
365
+ else:
366
+ raise ValueError(f"Unsupported layout: {self.layout}")
427
367
  else:
428
368
  raise ValueError(f"Unsupported IO backend: {io_backend}")
429
369
 
430
- def get_flat_data_page(self, index) -> torch.Tensor:
370
+ def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
431
371
  if self.layout == "layer_first":
432
- return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
372
+ data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
433
373
  elif self.layout == "page_first":
434
- return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
374
+ data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
375
+ elif self.layout == "page_first_direct":
376
+ real_index = index // self.page_size
377
+ data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
435
378
  else:
436
379
  raise ValueError(f"Unsupported layout: {self.layout}")
380
+ if flat:
381
+ data_page = data_page.flatten()
382
+ return data_page
437
383
 
438
384
  def get_dummy_flat_data_page(self) -> torch.Tensor:
439
385
  return torch.zeros(
@@ -460,12 +406,22 @@ class MHATokenToKVPoolHost(HostKVCache):
460
406
  2, self.page_size, self.layer_num, self.head_num, self.head_dim
461
407
  )
462
408
  )
409
+ elif self.layout == "page_first_direct":
410
+ real_index = index // self.page_size
411
+ self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
412
+ data_page.reshape(
413
+ 2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
414
+ )
415
+ )
463
416
  else:
464
417
  raise ValueError(f"Unsupported layout: {self.layout}")
465
418
 
466
- def get_buffer_meta(self, keys, indices, local_rank):
419
+ def get_page_buffer_meta(self, indices):
420
+ """ "
421
+ meta data for zero copy
422
+ """
423
+ assert len(indices) % self.page_size == 0
467
424
  ptr_list = []
468
- key_list = []
469
425
  kv_buffer_data_ptr = self.kv_buffer.data_ptr()
470
426
  indices = indices.tolist()
471
427
  v_offset = (
@@ -475,48 +431,52 @@ class MHATokenToKVPoolHost(HostKVCache):
475
431
  * self.head_dim
476
432
  * self.dtype.itemsize
477
433
  )
478
- for index in range(0, len(indices), self.page_size):
479
- k_ptr = (
480
- kv_buffer_data_ptr
481
- + indices[index]
482
- * self.layer_num
434
+ if self.layout == "layer_first":
435
+ for index in range(0, len(indices), self.page_size):
436
+ for layer_id in range(self.layer_num):
437
+ k_ptr = (
438
+ kv_buffer_data_ptr
439
+ + indices[index]
440
+ * self.head_num
441
+ * self.head_dim
442
+ * self.dtype.itemsize
443
+ + layer_id
444
+ * self.size
445
+ * self.head_num
446
+ * self.head_dim
447
+ * self.dtype.itemsize
448
+ )
449
+ v_ptr = k_ptr + v_offset
450
+ ptr_list.append(k_ptr)
451
+ ptr_list.append(v_ptr)
452
+ element_size = (
453
+ self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
454
+ )
455
+ element_size_list = [element_size] * len(ptr_list)
456
+ elif self.layout in ["page_first", "page_first_direct"]:
457
+ for index in range(0, len(indices), self.page_size):
458
+ k_ptr = (
459
+ kv_buffer_data_ptr
460
+ + indices[index]
461
+ * self.layer_num
462
+ * self.head_num
463
+ * self.head_dim
464
+ * self.dtype.itemsize
465
+ )
466
+ v_ptr = k_ptr + v_offset
467
+ ptr_list.append(k_ptr)
468
+ ptr_list.append(v_ptr)
469
+ element_size = (
470
+ self.layer_num
471
+ * self.dtype.itemsize
472
+ * self.page_size
483
473
  * self.head_num
484
474
  * self.head_dim
485
- * self.dtype.itemsize
486
475
  )
487
- v_ptr = k_ptr + v_offset
488
- ptr_list.append(k_ptr)
489
- ptr_list.append(v_ptr)
490
- key_ = keys[index // self.page_size]
491
- key_list.append(f"{key_}_{local_rank}_k")
492
- key_list.append(f"{key_}_{local_rank}_v")
493
- element_size = (
494
- self.layer_num
495
- * self.dtype.itemsize
496
- * self.page_size
497
- * self.head_num
498
- * self.head_dim
499
- )
500
- element_size_list = [element_size] * len(key_list)
501
- return key_list, ptr_list, element_size_list
502
-
503
- def get_buffer_with_hash(self, keys, indices=None):
504
- assert self.layout == "page_first"
505
- assert indices is None or (len(keys) == (len(indices) // self.page_size))
506
-
507
- key_list = []
508
- buf_list = []
509
-
510
- for i in range(len(keys)):
511
- key = keys[i]
512
- key_list.append(f"{key}-k")
513
- key_list.append(f"{key}-v")
514
- if indices is not None:
515
- index = indices[i * self.page_size]
516
- buf_list.append(self.k_buffer[index : index + self.page_size])
517
- buf_list.append(self.v_buffer[index : index + self.page_size])
518
-
519
- return key_list, buf_list, 2
476
+ element_size_list = [element_size] * len(ptr_list)
477
+ else:
478
+ raise ValueError(f"Unsupported layout: {self.layout}")
479
+ return ptr_list, element_size_list
520
480
 
521
481
 
522
482
  class MLATokenToKVPoolHost(HostKVCache):
@@ -578,6 +538,14 @@ class MLATokenToKVPoolHost(HostKVCache):
578
538
  1,
579
539
  self.kv_lora_rank + self.qk_rope_head_dim,
580
540
  )
541
+ elif self.layout == "page_first_direct":
542
+ dims = (
543
+ self.page_num,
544
+ self.layer_num,
545
+ self.page_size,
546
+ 1,
547
+ self.kv_lora_rank + self.qk_rope_head_dim,
548
+ )
581
549
  else:
582
550
  raise ValueError(f"Unsupported layout: {self.layout}")
583
551
  self.token_stride_size = (
@@ -617,16 +585,25 @@ class MLATokenToKVPoolHost(HostKVCache):
617
585
  else:
618
586
  raise ValueError(f"Unsupported layout: {self.layout}")
619
587
  elif io_backend == "direct":
620
- assert (
621
- self.layout == "layer_first"
622
- ), f"Direct IO backend only supports layer_first layout."
623
- transfer_kv_direct(
624
- src_layers=[self.kv_buffer[layer_id]],
625
- dst_layers=[device_pool.kv_buffer[layer_id]],
626
- src_indices=host_indices,
627
- dst_indices=device_indices,
628
- page_size=self.page_size,
629
- )
588
+ if self.layout == "layer_first":
589
+ transfer_kv_direct(
590
+ src_layers=[self.kv_buffer[layer_id]],
591
+ dst_layers=[device_pool.kv_buffer[layer_id]],
592
+ src_indices=host_indices,
593
+ dst_indices=device_indices,
594
+ page_size=self.page_size,
595
+ )
596
+ elif self.layout == "page_first_direct":
597
+ transfer_kv_per_layer_direct_pf_lf(
598
+ src_ptrs=[self.kv_buffer],
599
+ dst_ptrs=[device_pool.kv_buffer[layer_id]],
600
+ src_indices=host_indices,
601
+ dst_indices=device_indices,
602
+ layer_id=layer_id,
603
+ page_size=self.page_size,
604
+ )
605
+ else:
606
+ raise ValueError(f"Unsupported layout: {self.layout}")
630
607
 
631
608
  def backup_from_device_all_layer(
632
609
  self, device_pool, host_indices, device_indices, io_backend
@@ -654,26 +631,40 @@ class MLATokenToKVPoolHost(HostKVCache):
654
631
  else:
655
632
  raise ValueError(f"Unsupported layout: {self.layout}")
656
633
  elif io_backend == "direct":
657
- assert (
658
- self.layout == "layer_first"
659
- ), f"Direct IO backend only supports layer_first layout."
660
- transfer_kv_direct(
661
- src_layers=device_pool.kv_buffer,
662
- dst_layers=self.data_refs,
663
- src_indices=device_indices,
664
- dst_indices=host_indices,
665
- page_size=self.page_size,
666
- )
634
+ if self.layout == "layer_first":
635
+ transfer_kv_direct(
636
+ src_layers=device_pool.kv_buffer,
637
+ dst_layers=self.data_refs,
638
+ src_indices=device_indices,
639
+ dst_indices=host_indices,
640
+ page_size=self.page_size,
641
+ )
642
+ elif self.layout == "page_first_direct":
643
+ transfer_kv_all_layer_direct_lf_pf(
644
+ src_ptrs=device_pool.kv_buffer,
645
+ dst_ptrs=[self.kv_buffer],
646
+ src_indices=device_indices,
647
+ dst_indices=host_indices,
648
+ page_size=self.page_size,
649
+ )
650
+ else:
651
+ raise ValueError(f"Unsupported layout: {self.layout}")
667
652
  else:
668
653
  raise ValueError(f"Unsupported IO backend: {io_backend}")
669
654
 
670
- def get_flat_data_page(self, index) -> torch.Tensor:
655
+ def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
671
656
  if self.layout == "layer_first":
672
- return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
657
+ data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
673
658
  elif self.layout == "page_first":
674
- return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
659
+ data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
660
+ elif self.layout == "page_first_direct":
661
+ real_index = index // self.page_size
662
+ data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
675
663
  else:
676
664
  raise ValueError(f"Unsupported layout: {self.layout}")
665
+ if flat:
666
+ data_page = data_page.flatten()
667
+ return data_page
677
668
 
678
669
  def get_dummy_flat_data_page(self) -> torch.Tensor:
679
670
  return torch.zeros(
@@ -703,43 +694,63 @@ class MLATokenToKVPoolHost(HostKVCache):
703
694
  1,
704
695
  self.kv_lora_rank + self.qk_rope_head_dim,
705
696
  )
697
+ elif self.layout == "page_first_direct":
698
+ real_index = index // self.page_size
699
+ self.kv_buffer[real_index : real_index + 1, :, :, :, :] = data_page.reshape(
700
+ 1,
701
+ self.layer_num,
702
+ self.page_size,
703
+ 1,
704
+ self.kv_lora_rank + self.qk_rope_head_dim,
705
+ )
706
706
  else:
707
707
  raise ValueError(f"Unsupported layout: {self.layout}")
708
708
 
709
- def get_buffer_meta(self, keys, indices, local_rank):
709
+ def get_page_buffer_meta(self, indices):
710
+ """ "
711
+ meta data for zero copy
712
+ """
713
+ assert len(indices) % self.page_size == 0
710
714
  ptr_list = []
711
- key_list = []
712
715
  kv_buffer_data_ptr = self.kv_buffer.data_ptr()
713
716
  indices = indices.tolist()
714
- for index in range(0, len(indices), self.page_size):
715
- k_ptr = (
716
- kv_buffer_data_ptr
717
- + indices[index]
718
- * self.layer_num
717
+ if self.layout == "layer_first":
718
+ for index in range(0, len(indices), self.page_size):
719
+ for layer_id in range(self.layer_num):
720
+ k_ptr = (
721
+ kv_buffer_data_ptr
722
+ + indices[index]
723
+ * (self.kv_lora_rank + self.qk_rope_head_dim)
724
+ * self.dtype.itemsize
725
+ + layer_id
726
+ * self.size
727
+ * (self.kv_lora_rank + self.qk_rope_head_dim)
728
+ * self.dtype.itemsize
729
+ )
730
+ ptr_list.append(k_ptr)
731
+ element_size = (
732
+ self.dtype.itemsize
733
+ * self.page_size
719
734
  * (self.kv_lora_rank + self.qk_rope_head_dim)
735
+ )
736
+ element_size_list = [element_size] * len(ptr_list)
737
+ elif self.layout in ["page_first", "page_first_direct"]:
738
+ for index in range(0, len(indices), self.page_size):
739
+ k_ptr = (
740
+ kv_buffer_data_ptr
741
+ + indices[index]
742
+ * self.layer_num
743
+ * (self.kv_lora_rank + self.qk_rope_head_dim)
744
+ * self.dtype.itemsize
745
+ )
746
+ ptr_list.append(k_ptr)
747
+ element_size = (
748
+ self.layer_num
720
749
  * self.dtype.itemsize
750
+ * self.page_size
751
+ * (self.kv_lora_rank + self.qk_rope_head_dim)
721
752
  )
722
- ptr_list.append(k_ptr)
723
- key_ = keys[index // self.page_size]
724
- key_list.append(f"{key_}_k")
725
- element_size = (
726
- self.layer_num
727
- * self.dtype.itemsize
728
- * self.page_size
729
- * (self.kv_lora_rank + self.qk_rope_head_dim)
730
- )
731
- element_size_list = [element_size] * len(key_list)
732
- return key_list, ptr_list, element_size_list
733
-
734
- def get_buffer_with_hash(self, keys, indices=None):
735
- assert self.layout == "page_first"
736
- assert indices is None or (len(keys) == (len(indices) // self.page_size))
737
-
738
- buf_list = []
739
-
740
- if indices is not None:
741
- for i in range(len(keys)):
742
- index = indices[i * self.page_size]
743
- buf_list.append(self.kv_buffer[index : index + self.page_size])
744
-
745
- return keys, buf_list, 1
753
+ element_size_list = [element_size] * len(ptr_list)
754
+ else:
755
+ raise ValueError(f"Unsupported layout: {self.layout}")
756
+ return ptr_list, element_size_list