sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import math
18
18
  import threading
19
19
  import time
20
20
  from queue import Empty, Full, PriorityQueue, Queue
21
- from typing import TYPE_CHECKING, List, Optional
21
+ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
22
22
 
23
23
  import torch
24
24
 
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
33
33
  get_tensor_model_parallel_world_size,
34
34
  )
35
35
  from sglang.srt.layers.dp_attention import (
36
+ get_attention_dp_rank,
36
37
  get_attention_tp_rank,
37
38
  get_attention_tp_size,
38
39
  is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
42
43
  logger = logging.getLogger(__name__)
43
44
 
44
45
 
46
+ class LayerLoadingEvent:
47
+ def __init__(self, num_layers: int):
48
+ self._num_layers = num_layers
49
+ self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
50
+ self.start_event = torch.cuda.Event() # start event on controller stream
51
+
52
+ def complete(self, layer_index: int):
53
+ assert 0 <= layer_index < self._num_layers
54
+ self.load_events[layer_index].record()
55
+
56
+ def wait(self, layer_index: int):
57
+ torch.cuda.current_stream().wait_event(self.load_events[layer_index])
58
+
59
+ @property
60
+ def finish_event(self):
61
+ return self.load_events[-1]
62
+
63
+
45
64
  class LayerDoneCounter:
46
- def __init__(self, num_layers):
65
+ def __init__(self, num_layers: int):
47
66
  self.num_layers = num_layers
48
67
  # extra producer and consumer counters for overlap mode
49
68
  self.num_counters = 3
50
- self.counters = [num_layers] * self.num_counters
51
- self.conditions = [threading.Condition() for _ in range(self.num_counters)]
52
- self.producer_index = 0
53
- self.consumer_index = 0
54
-
55
- def next_producer(self):
56
- return (self.producer_index + 1) % self.num_counters
69
+ self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
70
+ self.producer_index = -1
71
+ self.consumer_index = -1
57
72
 
58
73
  def update_producer(self):
59
- self.producer_index = self.next_producer()
74
+ self.producer_index = (self.producer_index + 1) % self.num_counters
75
+ assert self.events[
76
+ self.producer_index
77
+ ].finish_event.query(), (
78
+ "Producer finish event should be ready before being reused."
79
+ )
60
80
  return self.producer_index
61
81
 
62
- def set_consumer(self, index):
82
+ def set_consumer(self, index: int):
63
83
  self.consumer_index = index
64
84
 
65
- def increment(self):
66
- with self.conditions[self.producer_index]:
67
- self.counters[self.producer_index] += 1
68
- self.conditions[self.producer_index].notify_all()
69
-
70
- def wait_until(self, threshold):
71
- with self.conditions[self.consumer_index]:
72
- while self.counters[self.consumer_index] <= threshold:
73
- self.conditions[self.consumer_index].wait()
85
+ def wait_until(self, threshold: int):
86
+ if self.consumer_index < 0:
87
+ return
88
+ self.events[self.consumer_index].wait(threshold)
74
89
 
75
90
  def reset(self):
76
- with self.conditions[self.producer_index]:
77
- self.counters[self.producer_index] = 0
91
+ self.producer_index = -1
92
+ self.consumer_index = -1
78
93
 
79
94
 
80
95
  class CacheOperation:
@@ -98,36 +113,30 @@ class CacheOperation:
98
113
  # default priority is the order of creation
99
114
  self.priority = priority if priority is not None else self.id
100
115
 
101
- def merge(self, other: "CacheOperation") -> None:
102
- # multiple operations can be merged into a single operation for batch processing
103
- self.host_indices = torch.cat([self.host_indices, other.host_indices])
104
- self.device_indices = torch.cat([self.device_indices, other.device_indices])
105
- self.priority = min(self.priority, other.priority)
106
- self.node_ids.extend(other.node_ids)
107
-
108
- def split(self, factor) -> List["CacheOperation"]:
109
- # split an operation into smaller operations to reduce the size of intermediate buffers
110
- if factor <= 1:
111
- return [self]
112
-
113
- chunk_size = math.ceil(len(self.host_indices) / factor)
114
- split_ops = []
115
- for i in range(0, len(self.host_indices), chunk_size):
116
- split_ops.append(
117
- CacheOperation(
118
- host_indices=self.host_indices[i : i + chunk_size],
119
- device_indices=self.device_indices[i : i + chunk_size],
120
- node_id=0,
121
- )
122
- )
123
- # Inherit the node_ids on the final chunk
124
- if split_ops:
125
- split_ops[-1].node_ids = self.node_ids
116
+ @staticmethod
117
+ def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
118
+ assert len(ops) > 0
119
+ if len(ops) == 1:
120
+ return ops[0]
121
+
122
+ host_indices = torch.cat([op.host_indices for op in ops])
123
+ device_indices = torch.cat([op.device_indices for op in ops])
124
+ node_ids = []
125
+ priority = min(op.priority for op in ops)
126
+ for op in ops:
127
+ node_ids.extend(op.node_ids)
128
+ merged_op = CacheOperation(host_indices, device_indices, -1, priority)
129
+ merged_op.node_ids = node_ids
130
+ return merged_op
131
+
132
+ def __lt__(self, other: CacheOperation):
133
+ return self.priority < other.priority
126
134
 
127
- return split_ops
128
135
 
129
- def __lt__(self, other: "CacheOperation"):
130
- return self.priority < other.priority
136
+ class HiCacheAck(NamedTuple):
137
+ start_event: torch.cuda.Event
138
+ finish_event: torch.cuda.Event
139
+ node_ids: List[int]
131
140
 
132
141
 
133
142
  class TransferBuffer:
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
206
215
  ):
207
216
  self.request_id = request_id
208
217
 
209
- self._done_flag = False
210
218
  self._lock = threading.Lock()
211
-
219
+ self._terminated_flag = False
212
220
  self.start_time = time.monotonic()
213
221
 
214
222
  super().__init__(host_indices, token_ids, last_hash)
215
223
 
216
224
  def increment(self, num_tokens: int):
217
225
  with self._lock:
218
- if self._done_flag:
226
+ if self._terminated_flag:
219
227
  return False
220
228
  self.completed_tokens += num_tokens
221
229
  return True
222
230
 
223
- def mark_done(self):
231
+ def mark_terminate(self):
224
232
  with self._lock:
225
- self._done_flag = True
233
+ self._terminated_flag = True
226
234
 
227
- def is_done(self) -> bool:
228
- return self._done_flag
235
+ def is_terminated(self) -> bool:
236
+ return self._terminated_flag
229
237
 
230
238
 
231
239
  class HiCacheController:
@@ -236,13 +244,13 @@ class HiCacheController:
236
244
  mem_pool_host: HostKVCache,
237
245
  page_size: int,
238
246
  tp_group: torch.distributed.ProcessGroup,
239
- load_cache_event: threading.Event = None,
247
+ load_cache_event: threading.Event,
240
248
  write_policy: str = "write_through_selective",
241
249
  io_backend: str = "",
242
250
  storage_backend: Optional[str] = None,
243
251
  prefetch_threshold: int = 256,
244
252
  model_name: Optional[str] = None,
245
- storage_backend_extra_config: Optional[str] = None,
253
+ storage_backend_extra_config: Optional[dict] = None,
246
254
  ):
247
255
  self.mem_pool_device_allocator = token_to_kv_pool_allocator
248
256
  self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -267,43 +275,17 @@ class HiCacheController:
267
275
  and self.storage_config.tp_rank != 0
268
276
  )
269
277
 
270
- if storage_backend == "file":
271
- from sglang.srt.mem_cache.hicache_storage import HiCacheFile
272
-
273
- self.storage_backend = HiCacheFile(self.storage_config)
274
- elif storage_backend == "nixl":
275
- from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
278
+ # Use storage backend factory for dynamic backend creation
279
+ from sglang.srt.mem_cache.storage import StorageBackendFactory
276
280
 
277
- self.storage_backend = HiCacheNixl()
278
- elif storage_backend == "mooncake":
279
- from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
280
- MooncakeStore,
281
- )
282
-
283
- self.storage_backend = MooncakeStore(self.storage_config)
284
- self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
285
- assert self.mem_pool_host.layout == "page_first"
286
- elif storage_backend == "hf3fs":
287
- from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
288
- HiCacheHF3FS,
281
+ try:
282
+ self.storage_backend = StorageBackendFactory.create_backend(
283
+ storage_backend, self.storage_config, self.mem_pool_host
289
284
  )
285
+ except ValueError as e:
286
+ raise ValueError(f"Failed to create storage backend: {e}") from e
290
287
 
291
- if self.mem_pool_host.layout == "page_first":
292
- bytes_per_page = (
293
- mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
294
- )
295
- elif self.mem_pool_host.layout == "layer_first":
296
- bytes_per_page = (
297
- mem_pool_host.get_size_per_token() * mem_pool_host.page_size
298
- )
299
- dtype = mem_pool_host.dtype
300
- self.storage_backend = HiCacheHF3FS.from_env_config(
301
- bytes_per_page, dtype, self.storage_config
302
- )
303
- else:
304
- raise NotImplementedError(
305
- f"Unsupported storage backend: {storage_backend}"
306
- )
288
+ self.storage_backend.register_mem_pool_host(self.mem_pool_host)
307
289
 
308
290
  self.enable_storage = True
309
291
  # todo: threshold policy for prefetching
@@ -324,8 +306,17 @@ class HiCacheController:
324
306
  group_ranks, backend="gloo"
325
307
  )
326
308
 
327
- self.load_cache_event = load_cache_event
328
- self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
309
+ # Select the get and set functions
310
+ self.page_get_func = self._generic_page_get
311
+ self.page_set_func = self._generic_page_set
312
+
313
+ if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
314
+ self.page_get_func = self._page_get_zero_copy
315
+ self.page_set_func = self._page_set_zero_copy
316
+
317
+ self.device = self.mem_pool_device.device
318
+ self.layer_num = self.mem_pool_device.layer_num
319
+ self.layer_done_counter = LayerDoneCounter(self.layer_num)
329
320
  self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
330
321
 
331
322
  if write_policy not in [
@@ -335,11 +326,11 @@ class HiCacheController:
335
326
  ]:
336
327
  raise ValueError(f"Invalid write policy: {write_policy}")
337
328
 
338
- self.write_queue = PriorityQueue()
339
- self.load_queue = PriorityQueue()
340
-
341
- self.ack_write_queue = Queue()
342
- self.ack_load_queue = Queue()
329
+ # self.write_queue = PriorityQueue[CacheOperation]()
330
+ self.load_queue: List[CacheOperation] = []
331
+ self.write_queue: List[CacheOperation] = []
332
+ self.ack_load_queue: List[HiCacheAck] = []
333
+ self.ack_write_queue: List[HiCacheAck] = []
343
334
 
344
335
  self.stop_event = threading.Event()
345
336
  self.write_buffer = TransferBuffer(self.stop_event)
@@ -350,16 +341,6 @@ class HiCacheController:
350
341
  self.write_stream = torch.cuda.Stream()
351
342
  self.load_stream = torch.cuda.Stream()
352
343
 
353
- self.write_thread = threading.Thread(
354
- target=self.write_thread_func_direct, daemon=True
355
- )
356
- self.load_thread = threading.Thread(
357
- target=self.load_thread_func_layer_by_layer, daemon=True
358
- )
359
-
360
- self.write_thread.start()
361
- self.load_thread.start()
362
-
363
344
  if self.enable_storage:
364
345
  self.prefetch_thread = threading.Thread(
365
346
  target=self.prefetch_thread_func, daemon=True
@@ -380,48 +361,39 @@ class HiCacheController:
380
361
  def _generate_storage_config(
381
362
  self,
382
363
  model_name: Optional[str] = None,
383
- storage_backend_extra_config: Optional[str] = None,
364
+ storage_backend_extra_config: Optional[dict] = None,
384
365
  ):
385
366
 
386
367
  if is_dp_attention_enabled():
387
368
  self.tp_rank = get_attention_tp_rank()
388
369
  self.tp_size = get_attention_tp_size()
370
+ self.dp_rank = get_attention_dp_rank()
389
371
  else:
390
372
  self.tp_rank = get_tensor_model_parallel_rank()
391
373
  self.tp_size = get_tensor_model_parallel_world_size()
374
+ self.dp_rank = 0
392
375
 
393
376
  # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
394
377
  is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
395
378
 
396
- # Parse extra config JSON if provided
397
- extra_config = None
398
- if storage_backend_extra_config:
399
- try:
400
- import json
401
-
402
- extra_config = json.loads(storage_backend_extra_config)
403
- except Exception as e:
404
- logger.error(f"Invalid backend extra config JSON: {e}")
405
-
406
379
  return HiCacheStorageConfig(
407
380
  tp_rank=self.tp_rank,
408
381
  tp_size=self.tp_size,
409
382
  is_mla_model=is_mla_backend,
383
+ is_page_first_layout=self.mem_pool_host.layout == "page_first",
410
384
  model_name=model_name,
411
- extra_config=extra_config,
385
+ extra_config=storage_backend_extra_config,
412
386
  )
413
387
 
414
388
  def reset(self):
415
389
  self.stop_event.set()
416
- self.write_thread.join()
417
- self.load_thread.join()
418
390
 
419
- self.write_queue.queue.clear()
420
- self.load_queue.queue.clear()
391
+ self.write_queue.clear()
392
+ self.load_queue.clear()
421
393
  self.write_buffer.clear()
422
394
  self.load_buffer.clear()
423
- self.ack_write_queue.queue.clear()
424
- self.ack_load_queue.queue.clear()
395
+ self.ack_write_queue.clear()
396
+ self.ack_load_queue.clear()
425
397
  if self.enable_storage:
426
398
  self.prefetch_thread.join()
427
399
  self.backup_thread.join()
@@ -430,15 +402,7 @@ class HiCacheController:
430
402
  self.prefetch_revoke_queue.queue.clear()
431
403
  self.ack_backup_queue.queue.clear()
432
404
 
433
- self.write_thread = threading.Thread(
434
- target=self.write_thread_func_direct, daemon=True
435
- )
436
- self.load_thread = threading.Thread(
437
- target=self.load_thread_func_layer_by_layer, daemon=True
438
- )
439
405
  self.stop_event.clear()
440
- self.write_thread.start()
441
- self.load_thread.start()
442
406
 
443
407
  if self.enable_storage:
444
408
  self.prefetch_thread = threading.Thread(
@@ -454,7 +418,7 @@ class HiCacheController:
454
418
  self,
455
419
  device_indices: torch.Tensor,
456
420
  priority: Optional[int] = None,
457
- node_id: int = 0,
421
+ node_id: int = -1,
458
422
  ) -> Optional[torch.Tensor]:
459
423
  """
460
424
  Back up KV caches from device memory to host memory.
@@ -462,18 +426,45 @@ class HiCacheController:
462
426
  host_indices = self.mem_pool_host.alloc(len(device_indices))
463
427
  if host_indices is None:
464
428
  return None
465
- self.mem_pool_host.protect_write(host_indices)
466
- torch.cuda.current_stream().synchronize()
467
- self.write_queue.put(
429
+ self.write_queue.append(
468
430
  CacheOperation(host_indices, device_indices, node_id, priority)
469
431
  )
432
+ self.start_writing()
470
433
  return host_indices
471
434
 
435
+ def start_writing(self) -> None:
436
+ if len(self.write_queue) == 0:
437
+ return
438
+
439
+ op = CacheOperation.merge_ops(self.write_queue)
440
+ host_indices, device_indices = self.move_indices(op)
441
+ self.write_queue.clear()
442
+
443
+ start_event = torch.cuda.Event()
444
+ finish_event = torch.cuda.Event()
445
+
446
+ start_event.record()
447
+ with torch.cuda.stream(self.write_stream):
448
+ start_event.wait(self.write_stream)
449
+ self.mem_pool_host.backup_from_device_all_layer(
450
+ self.mem_pool_device, host_indices, device_indices, self.io_backend
451
+ )
452
+ finish_event.record()
453
+ # NOTE: We must save the host indices and device indices here,
454
+ # this is because we need to guarantee that these tensors are
455
+ # still alive when the write stream is executing.
456
+ if host_indices.is_cuda:
457
+ host_indices.record_stream(self.write_stream)
458
+ if device_indices.is_cuda:
459
+ device_indices.record_stream(self.write_stream)
460
+
461
+ self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
462
+
472
463
  def load(
473
464
  self,
474
465
  host_indices: torch.Tensor,
475
466
  priority: Optional[int] = None,
476
- node_id: int = 0,
467
+ node_id: int = -1,
477
468
  ) -> Optional[torch.Tensor]:
478
469
  """
479
470
  Load KV caches from host memory to device memory.
@@ -481,77 +472,42 @@ class HiCacheController:
481
472
  device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
482
473
  if device_indices is None:
483
474
  return None
484
- self.mem_pool_host.protect_load(host_indices)
485
- # to ensure the device indices are ready before accessed by another CUDA stream
486
- torch.cuda.current_stream().synchronize()
487
- self.load_queue.put(
475
+ self.load_queue.append(
488
476
  CacheOperation(host_indices, device_indices, node_id, priority)
489
477
  )
490
478
  return device_indices
491
479
 
492
- def move_indices(self, host_indices, device_indices):
480
+ def move_indices(self, op: CacheOperation):
481
+ host_indices, device_indices = op.host_indices, op.device_indices
493
482
  # move indices to GPU if using kernels, to host if using direct indexing
494
483
  if self.io_backend == "kernel":
495
- return host_indices.to(self.mem_pool_device.device), device_indices
484
+ if not host_indices.is_cuda:
485
+ host_indices = host_indices.to(self.device, non_blocking=True)
486
+ return host_indices, device_indices
496
487
  elif self.io_backend == "direct":
497
- device_indices = device_indices.cpu()
498
- host_indices, idx = host_indices.sort()
499
- return host_indices, device_indices.index_select(0, idx)
488
+ if self.mem_pool_host.layout == "layer_first":
489
+ device_indices = device_indices.cpu()
490
+ host_indices, idx = host_indices.sort()
491
+ return host_indices, device_indices.index_select(0, idx)
492
+ elif self.mem_pool_host.layout == "page_first_direct":
493
+ return host_indices, device_indices.cpu()
500
494
  else:
501
495
  raise ValueError(f"Unsupported io backend")
502
496
 
503
- def write_thread_func_direct(self):
504
- """
505
- Directly write through KV caches to host memory without buffering.
506
- """
507
- torch.cuda.set_stream(self.write_stream)
508
- while not self.stop_event.is_set():
509
- try:
510
- operation = self.write_queue.get(block=True, timeout=1)
511
- host_indices, device_indices = self.move_indices(
512
- operation.host_indices, operation.device_indices
513
- )
514
- self.mem_pool_host.backup_from_device_all_layer(
515
- self.mem_pool_device, host_indices, device_indices, self.io_backend
516
- )
517
- self.write_stream.synchronize()
518
- self.mem_pool_host.complete_io(operation.host_indices)
519
- for node_id in operation.node_ids:
520
- if node_id != 0:
521
- self.ack_write_queue.put(node_id)
522
- except Empty:
523
- continue
524
- except Exception as e:
525
- logger.error(e)
497
+ def start_loading(self) -> int:
498
+ if len(self.load_queue) == 0:
499
+ return -1
526
500
 
527
- def load_thread_func_layer_by_layer(self):
528
- """
529
- Load KV caches from host memory to device memory layer by layer.
530
- """
531
- torch.cuda.set_stream(self.load_stream)
532
- while not self.stop_event.is_set():
533
- self.load_cache_event.wait(timeout=1)
534
- if not self.load_cache_event.is_set():
535
- continue
536
- self.load_cache_event.clear()
537
- self.layer_done_counter.update_producer()
538
-
539
- batch_operation = None
540
- while self.load_queue.qsize() > 0:
541
- op = self.load_queue.get(block=True)
542
- if batch_operation is None:
543
- batch_operation = op
544
- else:
545
- batch_operation.merge(op)
546
- if batch_operation is None:
547
- continue
501
+ producer_id = self.layer_done_counter.update_producer()
502
+ op = CacheOperation.merge_ops(self.load_queue)
503
+ host_indices, device_indices = self.move_indices(op)
504
+ self.load_queue.clear()
505
+ producer_event = self.layer_done_counter.events[producer_id]
506
+ producer_event.start_event.record()
548
507
 
549
- # start layer-wise KV cache transfer from CPU to GPU
550
- self.layer_done_counter.reset()
551
- host_indices, device_indices = self.move_indices(
552
- batch_operation.host_indices, batch_operation.device_indices
553
- )
554
- for i in range(self.mem_pool_host.layer_num):
508
+ with torch.cuda.stream(self.load_stream):
509
+ producer_event.start_event.wait(self.load_stream)
510
+ for i in range(self.layer_num):
555
511
  self.mem_pool_host.load_to_device_per_layer(
556
512
  self.mem_pool_device,
557
513
  host_indices,
@@ -559,37 +515,34 @@ class HiCacheController:
559
515
  i,
560
516
  self.io_backend,
561
517
  )
562
- self.load_stream.synchronize()
563
- self.layer_done_counter.increment()
564
-
565
- self.mem_pool_host.complete_io(batch_operation.host_indices)
566
- for node_id in batch_operation.node_ids:
567
- if node_id != 0:
568
- self.ack_load_queue.put(node_id)
569
-
570
- def evict_device(
571
- self, device_indices: torch.Tensor, host_indices: torch.Tensor
572
- ) -> int:
573
- if self.mem_pool_host.is_synced(host_indices):
574
- self.mem_pool_device_allocator.free(device_indices)
575
- self.mem_pool_host.update_backup(host_indices)
576
- return len(device_indices)
577
- else:
578
- raise ValueError(
579
- f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
518
+ producer_event.complete(i)
519
+ # NOTE: We must save the host indices and device indices here,
520
+ # this is because we need to guarantee that these tensors are
521
+ # still alive when the load stream is executing.
522
+ if host_indices.is_cuda:
523
+ host_indices.record_stream(self.load_stream)
524
+ if device_indices.is_cuda:
525
+ device_indices.record_stream(self.load_stream)
526
+
527
+ self.ack_load_queue.append(
528
+ HiCacheAck(
529
+ start_event=producer_event.start_event,
530
+ finish_event=producer_event.finish_event,
531
+ node_ids=op.node_ids,
580
532
  )
533
+ )
534
+ return producer_id
535
+
536
+ def evict_device(self, device_indices: torch.Tensor) -> int:
537
+ self.mem_pool_device_allocator.free(device_indices)
538
+ return len(device_indices)
581
539
 
582
540
  def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
583
541
  if not backup_only:
584
542
  raise ValueError("Other eviction policies are not supported yet.")
585
543
 
586
- if self.mem_pool_host.is_backup(host_indices):
587
- self.mem_pool_host.free(host_indices)
588
- return len(host_indices)
589
- else:
590
- raise ValueError(
591
- f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
592
- )
544
+ self.mem_pool_host.free(host_indices)
545
+ return len(host_indices)
593
546
 
594
547
  def prefetch(
595
548
  self,
@@ -608,48 +561,33 @@ class HiCacheController:
608
561
  return operation
609
562
 
610
563
  def terminate_prefetch(self, operation):
611
- operation.mark_done()
564
+ operation.mark_terminate()
612
565
  return operation.completed_tokens, operation.hash_value
613
566
 
614
567
  def append_host_mem_release(self, host_indices: torch.Tensor):
615
- chunks = host_indices.split(self.mem_pool_host.page_size)
616
- for chunk in chunks:
617
- self.host_mem_release_queue.put(chunk)
618
-
619
- def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
620
- hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
621
- hash_values, host_indices
622
- )
623
- page_data = self.storage_backend.batch_get(hashes, dsts)
624
- if page_data:
625
- operation.increment(self.page_size * len(hashes))
626
- else:
627
- logger.warning(
628
- f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
629
- )
568
+ if host_indices.numel() == 0:
569
+ return
570
+ pages = host_indices.split(self.mem_pool_host.page_size)
571
+ for page in pages:
572
+ self.host_mem_release_queue.put(page)
630
573
 
631
- def _mooncake_page_get(self, operation, hash_values, host_indices):
632
- key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
633
- hash_values,
634
- host_indices,
635
- self.storage_config.tp_rank,
636
- )
637
- get_result = self.storage_backend.batch_get(
638
- key_strs,
639
- target_location=buffer_ptrs,
640
- target_sizes=buffer_sizes,
641
- )
642
- if get_result != len(hash_values):
643
- logger.warning(
644
- f"Prefetch operation {operation.request_id} failed or partially failed."
645
- )
646
- if get_result != 0:
647
- operation.increment(get_result * self.page_size)
574
+ def _page_get_zero_copy(self, operation, hash_values, host_indices):
575
+ results = self.storage_backend.batch_get_v1(hash_values, host_indices)
576
+ inc = 0
577
+ for i in range(len(hash_values)):
578
+ if not results[i]:
579
+ logger.warning(
580
+ f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
581
+ )
582
+ break
583
+ inc += self.page_size
584
+ operation.increment(inc)
648
585
 
586
+ # todo: deprecate
649
587
  def _generic_page_get(self, operation, hash_values, host_indices):
650
- dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
651
- hash_values
652
- )
588
+ dummy_page_dst = [
589
+ self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
590
+ ]
653
591
  page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
654
592
  if page_data is None:
655
593
  return
@@ -659,26 +597,16 @@ class HiCacheController:
659
597
  f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
660
598
  )
661
599
  break
662
- if operation.increment(self.page_size):
663
- self.mem_pool_host.set_from_flat_data_page(
664
- host_indices[i * self.page_size],
665
- page_data[i],
666
- )
667
- else:
668
- break
600
+ # Must set the data before increasing the completed tokens.
601
+ # Otherwise this page may be read before being set.
602
+ self.mem_pool_host.set_from_flat_data_page(
603
+ host_indices[i * self.page_size],
604
+ page_data[i],
605
+ )
606
+ if not operation.increment(self.page_size):
607
+ break # Operation terminated by controller
669
608
 
670
609
  def _page_transfer(self, operation):
671
- # Select the get function and batch size
672
- if self.storage_backend_type == "mooncake":
673
- get_func = self._mooncake_page_get
674
- elif (
675
- self.storage_backend_type == "hf3fs"
676
- and self.mem_pool_host.layout == "page_first"
677
- ):
678
- get_func = self._3fs_zero_copy_page_get
679
- else:
680
- get_func = self._generic_page_get
681
-
682
610
  # Transfer batch by batch
683
611
  for i in range(0, len(operation.hash_value), self.storage_batch_size):
684
612
  batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -687,12 +615,13 @@ class HiCacheController:
687
615
  ]
688
616
  prev_completed_tokens = operation.completed_tokens
689
617
  # Get one batch token, and update the completed_tokens if succeed
690
- get_func(operation, batch_hashes, batch_host_indices)
618
+ self.page_get_func(operation, batch_hashes, batch_host_indices)
691
619
  # Check termination
692
620
  if (
693
621
  operation.completed_tokens
694
622
  != prev_completed_tokens + len(batch_hashes) * self.page_size
695
623
  ):
624
+ operation.mark_terminate()
696
625
  break # Some operations fail or operation terminated by controller
697
626
  # release pre-allocated memory
698
627
  self.append_host_mem_release(
@@ -813,47 +742,19 @@ class HiCacheController:
813
742
  self.backup_queue.put(operation)
814
743
  return operation.id
815
744
 
816
- # non-zero copy
745
+ # todo: deprecate
817
746
  def _generic_page_set(self, hash_values, host_indices) -> bool:
818
747
  data = [
819
- self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
748
+ self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
820
749
  for i in range(len(hash_values))
821
750
  ]
822
751
  return self.storage_backend.batch_set(hash_values, data)
823
752
 
824
- # zero copy
825
- def _mooncake_page_set(self, hash_values, host_indices) -> bool:
826
- key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
827
- hash_values,
828
- host_indices,
829
- self.storage_config.tp_rank,
830
- )
831
- success = self.storage_backend.batch_set(
832
- key_strs,
833
- target_location=buffer_ptrs,
834
- target_sizes=buffer_sizes,
835
- )
836
- return success
837
-
838
- # zero copy
839
- def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
840
- hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
841
- hash_values, host_indices
842
- )
843
- return self.storage_backend.batch_set(hashes, dsts)
753
+ def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
754
+ return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
844
755
 
845
756
  # Backup batch by batch
846
757
  def _page_backup(self, operation):
847
- # Select the set function and batch size
848
- if self.storage_backend_type == "mooncake":
849
- backup_set_func = self._mooncake_page_set
850
- elif (
851
- self.storage_backend_type == "hf3fs"
852
- and self.mem_pool_host.layout == "page_first"
853
- ):
854
- backup_set_func = self._3fs_zero_copy_page_set
855
- else:
856
- backup_set_func = self._generic_page_set
857
758
  # Backup batch by batch
858
759
  for i in range(0, len(operation.hash_value), self.storage_batch_size):
859
760
  batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -862,7 +763,7 @@ class HiCacheController:
862
763
  ]
863
764
  # Set one batch token, and record if success.
864
765
  # todo: allow partial success
865
- success = backup_set_func(batch_hashes, batch_host_indices)
766
+ success = self.page_set_func(batch_hashes, batch_host_indices)
866
767
  if not success:
867
768
  logger.warning(
868
769
  f"Write page to storage: {len(batch_hashes)} pages failed."
@@ -882,7 +783,7 @@ class HiCacheController:
882
783
 
883
784
  if not self.backup_skip:
884
785
  self._page_backup(operation)
885
- self.ack_backup_queue.put(operation.id)
786
+ self.ack_backup_queue.put(operation)
886
787
 
887
788
  except Empty:
888
789
  continue