sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import math
18
18
  import threading
19
19
  import time
20
20
  from queue import Empty, Full, PriorityQueue, Queue
21
- from typing import TYPE_CHECKING, List, Optional
21
+ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
22
22
 
23
23
  import torch
24
24
 
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
33
33
  get_tensor_model_parallel_world_size,
34
34
  )
35
35
  from sglang.srt.layers.dp_attention import (
36
+ get_attention_dp_rank,
36
37
  get_attention_tp_rank,
37
38
  get_attention_tp_size,
38
39
  is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
42
43
  logger = logging.getLogger(__name__)
43
44
 
44
45
 
46
+ class LayerLoadingEvent:
47
+ def __init__(self, num_layers: int):
48
+ self._num_layers = num_layers
49
+ self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
50
+ self.start_event = torch.cuda.Event() # start event on controller stream
51
+
52
+ def complete(self, layer_index: int):
53
+ assert 0 <= layer_index < self._num_layers
54
+ self.load_events[layer_index].record()
55
+
56
+ def wait(self, layer_index: int):
57
+ torch.cuda.current_stream().wait_event(self.load_events[layer_index])
58
+
59
+ @property
60
+ def finish_event(self):
61
+ return self.load_events[-1]
62
+
63
+
45
64
  class LayerDoneCounter:
46
- def __init__(self, num_layers):
65
+ def __init__(self, num_layers: int):
47
66
  self.num_layers = num_layers
48
67
  # extra producer and consumer counters for overlap mode
49
68
  self.num_counters = 3
50
- self.counters = [num_layers] * self.num_counters
51
- self.conditions = [threading.Condition() for _ in range(self.num_counters)]
52
- self.producer_index = 0
53
- self.consumer_index = 0
54
-
55
- def next_producer(self):
56
- return (self.producer_index + 1) % self.num_counters
69
+ self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
70
+ self.producer_index = -1
71
+ self.consumer_index = -1
57
72
 
58
73
  def update_producer(self):
59
- self.producer_index = self.next_producer()
74
+ self.producer_index = (self.producer_index + 1) % self.num_counters
75
+ assert self.events[
76
+ self.producer_index
77
+ ].finish_event.query(), (
78
+ "Producer finish event should be ready before being reused."
79
+ )
60
80
  return self.producer_index
61
81
 
62
- def set_consumer(self, index):
82
+ def set_consumer(self, index: int):
63
83
  self.consumer_index = index
64
84
 
65
- def increment(self):
66
- with self.conditions[self.producer_index]:
67
- self.counters[self.producer_index] += 1
68
- self.conditions[self.producer_index].notify_all()
69
-
70
- def wait_until(self, threshold):
71
- with self.conditions[self.consumer_index]:
72
- while self.counters[self.consumer_index] <= threshold:
73
- self.conditions[self.consumer_index].wait()
85
+ def wait_until(self, threshold: int):
86
+ if self.consumer_index < 0:
87
+ return
88
+ self.events[self.consumer_index].wait(threshold)
74
89
 
75
90
  def reset(self):
76
- with self.conditions[self.producer_index]:
77
- self.counters[self.producer_index] = 0
91
+ self.producer_index = -1
92
+ self.consumer_index = -1
78
93
 
79
94
 
80
95
  class CacheOperation:
@@ -98,36 +113,30 @@ class CacheOperation:
98
113
  # default priority is the order of creation
99
114
  self.priority = priority if priority is not None else self.id
100
115
 
101
- def merge(self, other: "CacheOperation") -> None:
102
- # multiple operations can be merged into a single operation for batch processing
103
- self.host_indices = torch.cat([self.host_indices, other.host_indices])
104
- self.device_indices = torch.cat([self.device_indices, other.device_indices])
105
- self.priority = min(self.priority, other.priority)
106
- self.node_ids.extend(other.node_ids)
107
-
108
- def split(self, factor) -> List["CacheOperation"]:
109
- # split an operation into smaller operations to reduce the size of intermediate buffers
110
- if factor <= 1:
111
- return [self]
112
-
113
- chunk_size = math.ceil(len(self.host_indices) / factor)
114
- split_ops = []
115
- for i in range(0, len(self.host_indices), chunk_size):
116
- split_ops.append(
117
- CacheOperation(
118
- host_indices=self.host_indices[i : i + chunk_size],
119
- device_indices=self.device_indices[i : i + chunk_size],
120
- node_id=0,
121
- )
122
- )
123
- # Inherit the node_ids on the final chunk
124
- if split_ops:
125
- split_ops[-1].node_ids = self.node_ids
116
+ @staticmethod
117
+ def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
118
+ assert len(ops) > 0
119
+ if len(ops) == 1:
120
+ return ops[0]
121
+
122
+ host_indices = torch.cat([op.host_indices for op in ops])
123
+ device_indices = torch.cat([op.device_indices for op in ops])
124
+ node_ids = []
125
+ priority = min(op.priority for op in ops)
126
+ for op in ops:
127
+ node_ids.extend(op.node_ids)
128
+ merged_op = CacheOperation(host_indices, device_indices, -1, priority)
129
+ merged_op.node_ids = node_ids
130
+ return merged_op
131
+
132
+ def __lt__(self, other: CacheOperation):
133
+ return self.priority < other.priority
126
134
 
127
- return split_ops
128
135
 
129
- def __lt__(self, other: "CacheOperation"):
130
- return self.priority < other.priority
136
+ class HiCacheAck(NamedTuple):
137
+ start_event: torch.cuda.Event
138
+ finish_event: torch.cuda.Event
139
+ node_ids: List[int]
131
140
 
132
141
 
133
142
  class TransferBuffer:
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
206
215
  ):
207
216
  self.request_id = request_id
208
217
 
209
- self._done_flag = False
210
218
  self._lock = threading.Lock()
211
-
219
+ self._terminated_flag = False
212
220
  self.start_time = time.monotonic()
213
221
 
214
222
  super().__init__(host_indices, token_ids, last_hash)
215
223
 
216
224
  def increment(self, num_tokens: int):
217
225
  with self._lock:
218
- if self._done_flag:
226
+ if self._terminated_flag:
219
227
  return False
220
228
  self.completed_tokens += num_tokens
221
229
  return True
222
230
 
223
- def mark_done(self):
231
+ def mark_terminate(self):
224
232
  with self._lock:
225
- self._done_flag = True
233
+ self._terminated_flag = True
226
234
 
227
- def is_done(self) -> bool:
228
- return self._done_flag
235
+ def is_terminated(self) -> bool:
236
+ return self._terminated_flag
229
237
 
230
238
 
231
239
  class HiCacheController:
@@ -236,13 +244,13 @@ class HiCacheController:
236
244
  mem_pool_host: HostKVCache,
237
245
  page_size: int,
238
246
  tp_group: torch.distributed.ProcessGroup,
239
- load_cache_event: threading.Event = None,
247
+ load_cache_event: threading.Event,
240
248
  write_policy: str = "write_through_selective",
241
249
  io_backend: str = "",
242
250
  storage_backend: Optional[str] = None,
243
251
  prefetch_threshold: int = 256,
244
252
  model_name: Optional[str] = None,
245
- storage_backend_extra_config: Optional[str] = None,
253
+ storage_backend_extra_config: Optional[dict] = None,
246
254
  ):
247
255
  self.mem_pool_device_allocator = token_to_kv_pool_allocator
248
256
  self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -267,43 +275,17 @@ class HiCacheController:
267
275
  and self.storage_config.tp_rank != 0
268
276
  )
269
277
 
270
- if storage_backend == "file":
271
- from sglang.srt.mem_cache.hicache_storage import HiCacheFile
272
-
273
- self.storage_backend = HiCacheFile(self.storage_config)
274
- elif storage_backend == "nixl":
275
- from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
276
-
277
- self.storage_backend = HiCacheNixl()
278
- elif storage_backend == "mooncake":
279
- from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
280
- MooncakeStore,
281
- )
278
+ # Use storage backend factory for dynamic backend creation
279
+ from sglang.srt.mem_cache.storage import StorageBackendFactory
282
280
 
283
- self.storage_backend = MooncakeStore(self.storage_config)
284
- self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
285
- assert self.mem_pool_host.layout == "page_first"
286
- elif storage_backend == "hf3fs":
287
- from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
288
- HiCacheHF3FS,
281
+ try:
282
+ self.storage_backend = StorageBackendFactory.create_backend(
283
+ storage_backend, self.storage_config, self.mem_pool_host
289
284
  )
285
+ except ValueError as e:
286
+ raise ValueError(f"Failed to create storage backend: {e}") from e
290
287
 
291
- if self.mem_pool_host.layout == "page_first":
292
- bytes_per_page = (
293
- mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
294
- )
295
- elif self.mem_pool_host.layout == "layer_first":
296
- bytes_per_page = (
297
- mem_pool_host.get_size_per_token() * mem_pool_host.page_size
298
- )
299
- dtype = mem_pool_host.dtype
300
- self.storage_backend = HiCacheHF3FS.from_env_config(
301
- bytes_per_page, dtype, self.storage_config
302
- )
303
- else:
304
- raise NotImplementedError(
305
- f"Unsupported storage backend: {storage_backend}"
306
- )
288
+ self.storage_backend.register_mem_pool_host(self.mem_pool_host)
307
289
 
308
290
  self.enable_storage = True
309
291
  # todo: threshold policy for prefetching
@@ -327,21 +309,14 @@ class HiCacheController:
327
309
  # Select the get and set functions
328
310
  self.page_get_func = self._generic_page_get
329
311
  self.page_set_func = self._generic_page_set
330
- self.batch_exists_func = self.storage_backend.batch_exists
331
- self.is_3fs_zerocopy = (
332
- self.storage_backend_type == "hf3fs"
333
- and self.mem_pool_host.layout == "page_first"
334
- )
335
- if self.storage_backend_type == "mooncake":
336
- self.page_get_func = self._mooncake_page_get
337
- self.page_set_func = self._mooncake_page_set
338
- elif self.is_3fs_zerocopy:
339
- self.page_get_func = self._3fs_zero_copy_page_get
340
- self.page_set_func = self._3fs_zero_copy_page_set
341
- self.batch_exists_func = self._3fs_zero_copy_batch_exists
342
-
343
- self.load_cache_event = load_cache_event
344
- self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
312
+
313
+ if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
314
+ self.page_get_func = self._page_get_zero_copy
315
+ self.page_set_func = self._page_set_zero_copy
316
+
317
+ self.device = self.mem_pool_device.device
318
+ self.layer_num = self.mem_pool_device.layer_num
319
+ self.layer_done_counter = LayerDoneCounter(self.layer_num)
345
320
  self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
346
321
 
347
322
  if write_policy not in [
@@ -351,11 +326,11 @@ class HiCacheController:
351
326
  ]:
352
327
  raise ValueError(f"Invalid write policy: {write_policy}")
353
328
 
354
- self.write_queue = PriorityQueue()
355
- self.load_queue = PriorityQueue()
356
-
357
- self.ack_write_queue = Queue()
358
- self.ack_load_queue = Queue()
329
+ # self.write_queue = PriorityQueue[CacheOperation]()
330
+ self.load_queue: List[CacheOperation] = []
331
+ self.write_queue: List[CacheOperation] = []
332
+ self.ack_load_queue: List[HiCacheAck] = []
333
+ self.ack_write_queue: List[HiCacheAck] = []
359
334
 
360
335
  self.stop_event = threading.Event()
361
336
  self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +341,6 @@ class HiCacheController:
366
341
  self.write_stream = torch.cuda.Stream()
367
342
  self.load_stream = torch.cuda.Stream()
368
343
 
369
- self.write_thread = threading.Thread(
370
- target=self.write_thread_func_direct, daemon=True
371
- )
372
- self.load_thread = threading.Thread(
373
- target=self.load_thread_func_layer_by_layer, daemon=True
374
- )
375
-
376
- self.write_thread.start()
377
- self.load_thread.start()
378
-
379
344
  if self.enable_storage:
380
345
  self.prefetch_thread = threading.Thread(
381
346
  target=self.prefetch_thread_func, daemon=True
@@ -396,49 +361,39 @@ class HiCacheController:
396
361
  def _generate_storage_config(
397
362
  self,
398
363
  model_name: Optional[str] = None,
399
- storage_backend_extra_config: Optional[str] = None,
364
+ storage_backend_extra_config: Optional[dict] = None,
400
365
  ):
401
366
 
402
367
  if is_dp_attention_enabled():
403
368
  self.tp_rank = get_attention_tp_rank()
404
369
  self.tp_size = get_attention_tp_size()
370
+ self.dp_rank = get_attention_dp_rank()
405
371
  else:
406
372
  self.tp_rank = get_tensor_model_parallel_rank()
407
373
  self.tp_size = get_tensor_model_parallel_world_size()
374
+ self.dp_rank = 0
408
375
 
409
376
  # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
410
377
  is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
411
378
 
412
- # Parse extra config JSON if provided
413
- extra_config = None
414
- if storage_backend_extra_config:
415
- try:
416
- import json
417
-
418
- extra_config = json.loads(storage_backend_extra_config)
419
- except Exception as e:
420
- logger.error(f"Invalid backend extra config JSON: {e}")
421
-
422
379
  return HiCacheStorageConfig(
423
380
  tp_rank=self.tp_rank,
424
381
  tp_size=self.tp_size,
425
382
  is_mla_model=is_mla_backend,
426
383
  is_page_first_layout=self.mem_pool_host.layout == "page_first",
427
384
  model_name=model_name,
428
- extra_config=extra_config,
385
+ extra_config=storage_backend_extra_config,
429
386
  )
430
387
 
431
388
  def reset(self):
432
389
  self.stop_event.set()
433
- self.write_thread.join()
434
- self.load_thread.join()
435
390
 
436
- self.write_queue.queue.clear()
437
- self.load_queue.queue.clear()
391
+ self.write_queue.clear()
392
+ self.load_queue.clear()
438
393
  self.write_buffer.clear()
439
394
  self.load_buffer.clear()
440
- self.ack_write_queue.queue.clear()
441
- self.ack_load_queue.queue.clear()
395
+ self.ack_write_queue.clear()
396
+ self.ack_load_queue.clear()
442
397
  if self.enable_storage:
443
398
  self.prefetch_thread.join()
444
399
  self.backup_thread.join()
@@ -447,15 +402,7 @@ class HiCacheController:
447
402
  self.prefetch_revoke_queue.queue.clear()
448
403
  self.ack_backup_queue.queue.clear()
449
404
 
450
- self.write_thread = threading.Thread(
451
- target=self.write_thread_func_direct, daemon=True
452
- )
453
- self.load_thread = threading.Thread(
454
- target=self.load_thread_func_layer_by_layer, daemon=True
455
- )
456
405
  self.stop_event.clear()
457
- self.write_thread.start()
458
- self.load_thread.start()
459
406
 
460
407
  if self.enable_storage:
461
408
  self.prefetch_thread = threading.Thread(
@@ -471,7 +418,7 @@ class HiCacheController:
471
418
  self,
472
419
  device_indices: torch.Tensor,
473
420
  priority: Optional[int] = None,
474
- node_id: int = 0,
421
+ node_id: int = -1,
475
422
  ) -> Optional[torch.Tensor]:
476
423
  """
477
424
  Back up KV caches from device memory to host memory.
@@ -479,18 +426,45 @@ class HiCacheController:
479
426
  host_indices = self.mem_pool_host.alloc(len(device_indices))
480
427
  if host_indices is None:
481
428
  return None
482
- self.mem_pool_host.protect_write(host_indices)
483
- torch.cuda.current_stream().synchronize()
484
- self.write_queue.put(
429
+ self.write_queue.append(
485
430
  CacheOperation(host_indices, device_indices, node_id, priority)
486
431
  )
432
+ self.start_writing()
487
433
  return host_indices
488
434
 
435
+ def start_writing(self) -> None:
436
+ if len(self.write_queue) == 0:
437
+ return
438
+
439
+ op = CacheOperation.merge_ops(self.write_queue)
440
+ host_indices, device_indices = self.move_indices(op)
441
+ self.write_queue.clear()
442
+
443
+ start_event = torch.cuda.Event()
444
+ finish_event = torch.cuda.Event()
445
+
446
+ start_event.record()
447
+ with torch.cuda.stream(self.write_stream):
448
+ start_event.wait(self.write_stream)
449
+ self.mem_pool_host.backup_from_device_all_layer(
450
+ self.mem_pool_device, host_indices, device_indices, self.io_backend
451
+ )
452
+ finish_event.record()
453
+ # NOTE: We must save the host indices and device indices here,
454
+ # this is because we need to guarantee that these tensors are
455
+ # still alive when the write stream is executing.
456
+ if host_indices.is_cuda:
457
+ host_indices.record_stream(self.write_stream)
458
+ if device_indices.is_cuda:
459
+ device_indices.record_stream(self.write_stream)
460
+
461
+ self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
462
+
489
463
  def load(
490
464
  self,
491
465
  host_indices: torch.Tensor,
492
466
  priority: Optional[int] = None,
493
- node_id: int = 0,
467
+ node_id: int = -1,
494
468
  ) -> Optional[torch.Tensor]:
495
469
  """
496
470
  Load KV caches from host memory to device memory.
@@ -498,77 +472,42 @@ class HiCacheController:
498
472
  device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
499
473
  if device_indices is None:
500
474
  return None
501
- self.mem_pool_host.protect_load(host_indices)
502
- # to ensure the device indices are ready before accessed by another CUDA stream
503
- torch.cuda.current_stream().synchronize()
504
- self.load_queue.put(
475
+ self.load_queue.append(
505
476
  CacheOperation(host_indices, device_indices, node_id, priority)
506
477
  )
507
478
  return device_indices
508
479
 
509
- def move_indices(self, host_indices, device_indices):
480
+ def move_indices(self, op: CacheOperation):
481
+ host_indices, device_indices = op.host_indices, op.device_indices
510
482
  # move indices to GPU if using kernels, to host if using direct indexing
511
483
  if self.io_backend == "kernel":
512
- return host_indices.to(self.mem_pool_device.device), device_indices
484
+ if not host_indices.is_cuda:
485
+ host_indices = host_indices.to(self.device, non_blocking=True)
486
+ return host_indices, device_indices
513
487
  elif self.io_backend == "direct":
514
- device_indices = device_indices.cpu()
515
- host_indices, idx = host_indices.sort()
516
- return host_indices, device_indices.index_select(0, idx)
488
+ if self.mem_pool_host.layout == "layer_first":
489
+ device_indices = device_indices.cpu()
490
+ host_indices, idx = host_indices.sort()
491
+ return host_indices, device_indices.index_select(0, idx)
492
+ elif self.mem_pool_host.layout == "page_first_direct":
493
+ return host_indices, device_indices.cpu()
517
494
  else:
518
495
  raise ValueError(f"Unsupported io backend")
519
496
 
520
- def write_thread_func_direct(self):
521
- """
522
- Directly write through KV caches to host memory without buffering.
523
- """
524
- torch.cuda.set_stream(self.write_stream)
525
- while not self.stop_event.is_set():
526
- try:
527
- operation = self.write_queue.get(block=True, timeout=1)
528
- host_indices, device_indices = self.move_indices(
529
- operation.host_indices, operation.device_indices
530
- )
531
- self.mem_pool_host.backup_from_device_all_layer(
532
- self.mem_pool_device, host_indices, device_indices, self.io_backend
533
- )
534
- self.write_stream.synchronize()
535
- self.mem_pool_host.complete_io(operation.host_indices)
536
- for node_id in operation.node_ids:
537
- if node_id != 0:
538
- self.ack_write_queue.put(node_id)
539
- except Empty:
540
- continue
541
- except Exception as e:
542
- logger.error(e)
497
+ def start_loading(self) -> int:
498
+ if len(self.load_queue) == 0:
499
+ return -1
543
500
 
544
- def load_thread_func_layer_by_layer(self):
545
- """
546
- Load KV caches from host memory to device memory layer by layer.
547
- """
548
- torch.cuda.set_stream(self.load_stream)
549
- while not self.stop_event.is_set():
550
- self.load_cache_event.wait(timeout=1)
551
- if not self.load_cache_event.is_set():
552
- continue
553
- self.load_cache_event.clear()
554
- self.layer_done_counter.update_producer()
555
-
556
- batch_operation = None
557
- while self.load_queue.qsize() > 0:
558
- op = self.load_queue.get(block=True)
559
- if batch_operation is None:
560
- batch_operation = op
561
- else:
562
- batch_operation.merge(op)
563
- if batch_operation is None:
564
- continue
501
+ producer_id = self.layer_done_counter.update_producer()
502
+ op = CacheOperation.merge_ops(self.load_queue)
503
+ host_indices, device_indices = self.move_indices(op)
504
+ self.load_queue.clear()
505
+ producer_event = self.layer_done_counter.events[producer_id]
506
+ producer_event.start_event.record()
565
507
 
566
- # start layer-wise KV cache transfer from CPU to GPU
567
- self.layer_done_counter.reset()
568
- host_indices, device_indices = self.move_indices(
569
- batch_operation.host_indices, batch_operation.device_indices
570
- )
571
- for i in range(self.mem_pool_host.layer_num):
508
+ with torch.cuda.stream(self.load_stream):
509
+ producer_event.start_event.wait(self.load_stream)
510
+ for i in range(self.layer_num):
572
511
  self.mem_pool_host.load_to_device_per_layer(
573
512
  self.mem_pool_device,
574
513
  host_indices,
@@ -576,37 +515,34 @@ class HiCacheController:
576
515
  i,
577
516
  self.io_backend,
578
517
  )
579
- self.load_stream.synchronize()
580
- self.layer_done_counter.increment()
581
-
582
- self.mem_pool_host.complete_io(batch_operation.host_indices)
583
- for node_id in batch_operation.node_ids:
584
- if node_id != 0:
585
- self.ack_load_queue.put(node_id)
586
-
587
- def evict_device(
588
- self, device_indices: torch.Tensor, host_indices: torch.Tensor
589
- ) -> int:
590
- if self.mem_pool_host.is_synced(host_indices):
591
- self.mem_pool_device_allocator.free(device_indices)
592
- self.mem_pool_host.update_backup(host_indices)
593
- return len(device_indices)
594
- else:
595
- raise ValueError(
596
- f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
518
+ producer_event.complete(i)
519
+ # NOTE: We must save the host indices and device indices here,
520
+ # this is because we need to guarantee that these tensors are
521
+ # still alive when the load stream is executing.
522
+ if host_indices.is_cuda:
523
+ host_indices.record_stream(self.load_stream)
524
+ if device_indices.is_cuda:
525
+ device_indices.record_stream(self.load_stream)
526
+
527
+ self.ack_load_queue.append(
528
+ HiCacheAck(
529
+ start_event=producer_event.start_event,
530
+ finish_event=producer_event.finish_event,
531
+ node_ids=op.node_ids,
597
532
  )
533
+ )
534
+ return producer_id
535
+
536
+ def evict_device(self, device_indices: torch.Tensor) -> int:
537
+ self.mem_pool_device_allocator.free(device_indices)
538
+ return len(device_indices)
598
539
 
599
540
  def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
600
541
  if not backup_only:
601
542
  raise ValueError("Other eviction policies are not supported yet.")
602
543
 
603
- if self.mem_pool_host.is_backup(host_indices):
604
- self.mem_pool_host.free(host_indices)
605
- return len(host_indices)
606
- else:
607
- raise ValueError(
608
- f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
609
- )
544
+ self.mem_pool_host.free(host_indices)
545
+ return len(host_indices)
610
546
 
611
547
  def prefetch(
612
548
  self,
@@ -625,50 +561,29 @@ class HiCacheController:
625
561
  return operation
626
562
 
627
563
  def terminate_prefetch(self, operation):
628
- operation.mark_done()
564
+ operation.mark_terminate()
629
565
  return operation.completed_tokens, operation.hash_value
630
566
 
631
567
  def append_host_mem_release(self, host_indices: torch.Tensor):
632
- chunks = host_indices.split(self.mem_pool_host.page_size)
633
- for chunk in chunks:
634
- self.host_mem_release_queue.put(chunk)
635
-
636
- def _3fs_zero_copy_batch_exists(self, batch_hashes):
637
- _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
638
- hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
639
- return hit_page_num
640
-
641
- def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
642
- hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
643
- hash_values, host_indices
644
- )
645
- page_data = self.storage_backend.batch_get(hashes, dsts)
646
- if page_data:
647
- inc = self.page_size * len(hashes) // factor
648
- operation.increment(inc)
649
- else:
650
- logger.warning(
651
- f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
652
- )
568
+ if host_indices.numel() == 0:
569
+ return
570
+ pages = host_indices.split(self.mem_pool_host.page_size)
571
+ for page in pages:
572
+ self.host_mem_release_queue.put(page)
653
573
 
654
- def _mooncake_page_get(self, operation, hash_values, host_indices):
655
- key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
656
- hash_values,
657
- host_indices,
658
- self.storage_config.tp_rank,
659
- )
660
- get_result = self.storage_backend.batch_get(
661
- key_strs,
662
- target_locations=buffer_ptrs,
663
- target_sizes=buffer_sizes,
664
- )
665
- if get_result != len(hash_values):
666
- logger.warning(
667
- f"Prefetch operation {operation.request_id} failed or partially failed."
668
- )
669
- if get_result != 0:
670
- operation.increment(get_result * self.page_size)
574
+ def _page_get_zero_copy(self, operation, hash_values, host_indices):
575
+ results = self.storage_backend.batch_get_v1(hash_values, host_indices)
576
+ inc = 0
577
+ for i in range(len(hash_values)):
578
+ if not results[i]:
579
+ logger.warning(
580
+ f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
581
+ )
582
+ break
583
+ inc += self.page_size
584
+ operation.increment(inc)
671
585
 
586
+ # todo: deprecate
672
587
  def _generic_page_get(self, operation, hash_values, host_indices):
673
588
  dummy_page_dst = [
674
589
  self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
@@ -706,6 +621,7 @@ class HiCacheController:
706
621
  operation.completed_tokens
707
622
  != prev_completed_tokens + len(batch_hashes) * self.page_size
708
623
  ):
624
+ operation.mark_terminate()
709
625
  break # Some operations fail or operation terminated by controller
710
626
  # release pre-allocated memory
711
627
  self.append_host_mem_release(
@@ -757,7 +673,7 @@ class HiCacheController:
757
673
  batch_tokens[i : i + self.page_size], last_hash
758
674
  )
759
675
  batch_hashes.append(last_hash)
760
- hit_page_num = self.batch_exists_func(batch_hashes)
676
+ hit_page_num = self.storage_backend.batch_exists(batch_hashes)
761
677
  hash_value.extend(batch_hashes[:hit_page_num])
762
678
  storage_query_count += hit_page_num * self.page_size
763
679
  if hit_page_num < len(batch_hashes):
@@ -826,34 +742,16 @@ class HiCacheController:
826
742
  self.backup_queue.put(operation)
827
743
  return operation.id
828
744
 
829
- # non-zero copy
745
+ # todo: deprecate
830
746
  def _generic_page_set(self, hash_values, host_indices) -> bool:
831
747
  data = [
832
- self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
748
+ self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
833
749
  for i in range(len(hash_values))
834
750
  ]
835
751
  return self.storage_backend.batch_set(hash_values, data)
836
752
 
837
- # zero copy
838
- def _mooncake_page_set(self, hash_values, host_indices) -> bool:
839
- key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
840
- hash_values,
841
- host_indices,
842
- self.storage_config.tp_rank,
843
- )
844
- success = self.storage_backend.batch_set(
845
- key_strs,
846
- target_locations=buffer_ptrs,
847
- target_sizes=buffer_sizes,
848
- )
849
- return success
850
-
851
- # zero copy
852
- def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
853
- hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
854
- hash_values, host_indices
855
- )
856
- return self.storage_backend.batch_set(hashes, dsts)
753
+ def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
754
+ return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
857
755
 
858
756
  # Backup batch by batch
859
757
  def _page_backup(self, operation):
@@ -885,7 +783,7 @@ class HiCacheController:
885
783
 
886
784
  if not self.backup_skip:
887
785
  self._page_backup(operation)
888
- self.ack_backup_queue.put(operation.id)
786
+ self.ack_backup_queue.put(operation)
889
787
 
890
788
  except Empty:
891
789
  continue