sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482):
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -14,15 +14,17 @@ limitations under the License.
14
14
  """
15
15
 
16
16
  import logging
17
- import math
18
17
  import threading
19
18
  import time
20
- from queue import Empty, Full, PriorityQueue, Queue
21
- from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
19
+ from queue import Empty, Full, Queue
20
+ from typing import TYPE_CHECKING, List, NamedTuple, Optional
22
21
 
23
22
  import torch
24
23
 
25
- from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
24
+ from sglang.srt.mem_cache.hicache_storage import (
25
+ HiCacheStorageConfig,
26
+ HiCacheStorageExtraInfo,
27
+ )
26
28
 
27
29
  if TYPE_CHECKING:
28
30
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
@@ -38,7 +40,7 @@ from sglang.srt.layers.dp_attention import (
38
40
  get_attention_tp_size,
39
41
  is_dp_attention_enabled,
40
42
  )
41
- from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
43
+ from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
42
44
 
43
45
  logger = logging.getLogger(__name__)
44
46
 
@@ -191,12 +193,14 @@ class StorageOperation:
191
193
  token_ids: List[int],
192
194
  last_hash: Optional[str] = None,
193
195
  hash_value: Optional[List[str]] = None,
196
+ prefix_keys: Optional[List[str]] = None,
194
197
  ):
195
198
  self.host_indices = host_indices
196
199
  self.token_ids = token_ids
197
200
  self.last_hash = last_hash
198
201
  self.completed_tokens = 0
199
202
  self.hash_value = hash_value if hash_value is not None else []
203
+ self.prefix_keys = prefix_keys
200
204
 
201
205
  self.id = StorageOperation.counter
202
206
  StorageOperation.counter += 1
@@ -212,6 +216,7 @@ class PrefetchOperation(StorageOperation):
212
216
  host_indices: torch.Tensor,
213
217
  token_ids: List[int],
214
218
  last_hash: Optional[str] = None,
219
+ prefix_keys: Optional[List[str]] = None,
215
220
  ):
216
221
  self.request_id = request_id
217
222
 
@@ -219,7 +224,7 @@ class PrefetchOperation(StorageOperation):
219
224
  self._terminated_flag = False
220
225
  self.start_time = time.monotonic()
221
226
 
222
- super().__init__(host_indices, token_ids, last_hash)
227
+ super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys)
223
228
 
224
229
  def increment(self, num_tokens: int):
225
230
  with self._lock:
@@ -250,7 +255,7 @@ class HiCacheController:
250
255
  storage_backend: Optional[str] = None,
251
256
  prefetch_threshold: int = 256,
252
257
  model_name: Optional[str] = None,
253
- storage_backend_extra_config: Optional[str] = None,
258
+ storage_backend_extra_config: Optional[dict] = None,
254
259
  ):
255
260
  self.mem_pool_device_allocator = token_to_kv_pool_allocator
256
261
  self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -275,43 +280,17 @@ class HiCacheController:
275
280
  and self.storage_config.tp_rank != 0
276
281
  )
277
282
 
278
- if storage_backend == "file":
279
- from sglang.srt.mem_cache.hicache_storage import HiCacheFile
280
-
281
- self.storage_backend = HiCacheFile(self.storage_config)
282
- elif storage_backend == "nixl":
283
- from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
283
+ # Use storage backend factory for dynamic backend creation
284
+ from sglang.srt.mem_cache.storage import StorageBackendFactory
284
285
 
285
- self.storage_backend = HiCacheNixl()
286
- elif storage_backend == "mooncake":
287
- from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
288
- MooncakeStore,
289
- )
290
-
291
- self.storage_backend = MooncakeStore(self.storage_config)
292
- self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
293
- assert self.mem_pool_host.layout == "page_first"
294
- elif storage_backend == "hf3fs":
295
- from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
296
- HiCacheHF3FS,
286
+ try:
287
+ self.storage_backend = StorageBackendFactory.create_backend(
288
+ storage_backend, self.storage_config, self.mem_pool_host
297
289
  )
290
+ except ValueError as e:
291
+ raise ValueError(f"Failed to create storage backend: {e}") from e
298
292
 
299
- if self.mem_pool_host.layout == "page_first":
300
- bytes_per_page = (
301
- mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
302
- )
303
- elif self.mem_pool_host.layout == "layer_first":
304
- bytes_per_page = (
305
- mem_pool_host.get_size_per_token() * mem_pool_host.page_size
306
- )
307
- dtype = mem_pool_host.dtype
308
- self.storage_backend = HiCacheHF3FS.from_env_config(
309
- bytes_per_page, dtype, self.storage_config
310
- )
311
- else:
312
- raise NotImplementedError(
313
- f"Unsupported storage backend: {storage_backend}"
314
- )
293
+ self.storage_backend.register_mem_pool_host(self.mem_pool_host)
315
294
 
316
295
  self.enable_storage = True
317
296
  # todo: threshold policy for prefetching
@@ -335,18 +314,10 @@ class HiCacheController:
335
314
  # Select the get and set functions
336
315
  self.page_get_func = self._generic_page_get
337
316
  self.page_set_func = self._generic_page_set
338
- self.batch_exists_func = self.storage_backend.batch_exists
339
- self.is_3fs_zerocopy = (
340
- self.storage_backend_type == "hf3fs"
341
- and self.mem_pool_host.layout == "page_first"
342
- )
343
- if self.storage_backend_type == "mooncake":
344
- self.page_get_func = self._mooncake_page_get
345
- self.page_set_func = self._mooncake_page_set
346
- elif self.is_3fs_zerocopy:
347
- self.page_get_func = self._3fs_zero_copy_page_get
348
- self.page_set_func = self._3fs_zero_copy_page_set
349
- self.batch_exists_func = self._3fs_zero_copy_batch_exists
317
+
318
+ if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
319
+ self.page_get_func = self._page_get_zero_copy
320
+ self.page_set_func = self._page_set_zero_copy
350
321
 
351
322
  self.device = self.mem_pool_device.device
352
323
  self.layer_num = self.mem_pool_device.layer_num
@@ -395,7 +366,7 @@ class HiCacheController:
395
366
  def _generate_storage_config(
396
367
  self,
397
368
  model_name: Optional[str] = None,
398
- storage_backend_extra_config: Optional[str] = None,
369
+ storage_backend_extra_config: Optional[dict] = None,
399
370
  ):
400
371
 
401
372
  if is_dp_attention_enabled():
@@ -410,23 +381,13 @@ class HiCacheController:
410
381
  # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
411
382
  is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
412
383
 
413
- # Parse extra config JSON if provided
414
- extra_config = None
415
- if storage_backend_extra_config:
416
- try:
417
- import json
418
-
419
- extra_config = json.loads(storage_backend_extra_config)
420
- except Exception as e:
421
- logger.error(f"Invalid backend extra config JSON: {e}")
422
-
423
384
  return HiCacheStorageConfig(
424
385
  tp_rank=self.tp_rank,
425
386
  tp_size=self.tp_size,
426
387
  is_mla_model=is_mla_backend,
427
388
  is_page_first_layout=self.mem_pool_host.layout == "page_first",
428
389
  model_name=model_name,
429
- extra_config=extra_config,
390
+ extra_config=storage_backend_extra_config,
430
391
  )
431
392
 
432
393
  def reset(self):
@@ -470,7 +431,6 @@ class HiCacheController:
470
431
  host_indices = self.mem_pool_host.alloc(len(device_indices))
471
432
  if host_indices is None:
472
433
  return None
473
- self.mem_pool_host.protect_write(host_indices)
474
434
  self.write_queue.append(
475
435
  CacheOperation(host_indices, device_indices, node_id, priority)
476
436
  )
@@ -494,7 +454,6 @@ class HiCacheController:
494
454
  self.mem_pool_host.backup_from_device_all_layer(
495
455
  self.mem_pool_device, host_indices, device_indices, self.io_backend
496
456
  )
497
- self.mem_pool_host.complete_io(op.host_indices)
498
457
  finish_event.record()
499
458
  # NOTE: We must save the host indices and device indices here,
500
459
  # this is because we need to guarantee that these tensors are
@@ -518,7 +477,6 @@ class HiCacheController:
518
477
  device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
519
478
  if device_indices is None:
520
479
  return None
521
- self.mem_pool_host.protect_load(host_indices)
522
480
  self.load_queue.append(
523
481
  CacheOperation(host_indices, device_indices, node_id, priority)
524
482
  )
@@ -563,7 +521,6 @@ class HiCacheController:
563
521
  self.io_backend,
564
522
  )
565
523
  producer_event.complete(i)
566
- self.mem_pool_host.complete_io(op.host_indices)
567
524
  # NOTE: We must save the host indices and device indices here,
568
525
  # this is because we need to guarantee that these tensors are
569
526
  # still alive when the load stream is executing.
@@ -581,29 +538,16 @@ class HiCacheController:
581
538
  )
582
539
  return producer_id
583
540
 
584
- def evict_device(
585
- self, device_indices: torch.Tensor, host_indices: torch.Tensor
586
- ) -> int:
587
- if self.mem_pool_host.is_synced(host_indices):
588
- self.mem_pool_device_allocator.free(device_indices)
589
- self.mem_pool_host.update_backup(host_indices)
590
- return len(device_indices)
591
- else:
592
- raise ValueError(
593
- f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
594
- )
541
+ def evict_device(self, device_indices: torch.Tensor) -> int:
542
+ self.mem_pool_device_allocator.free(device_indices)
543
+ return len(device_indices)
595
544
 
596
545
  def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
597
546
  if not backup_only:
598
547
  raise ValueError("Other eviction policies are not supported yet.")
599
548
 
600
- if self.mem_pool_host.is_backup(host_indices):
601
- self.mem_pool_host.free(host_indices)
602
- return len(host_indices)
603
- else:
604
- raise ValueError(
605
- f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
606
- )
549
+ self.mem_pool_host.free(host_indices)
550
+ return len(host_indices)
607
551
 
608
552
  def prefetch(
609
553
  self,
@@ -611,12 +555,13 @@ class HiCacheController:
611
555
  host_indices: torch.Tensor,
612
556
  new_input_tokens: List[int],
613
557
  last_hash: Optional[str] = None,
558
+ prefix_keys: Optional[List[str]] = None,
614
559
  ) -> PrefetchOperation:
615
560
  """
616
561
  Prefetch KV caches from storage backend to host memory.
617
562
  """
618
563
  operation = PrefetchOperation(
619
- request_id, host_indices, new_input_tokens, last_hash
564
+ request_id, host_indices, new_input_tokens, last_hash, prefix_keys
620
565
  )
621
566
  self.prefetch_queue.put(operation)
622
567
  return operation
@@ -626,47 +571,30 @@ class HiCacheController:
626
571
  return operation.completed_tokens, operation.hash_value
627
572
 
628
573
  def append_host_mem_release(self, host_indices: torch.Tensor):
629
- chunks = host_indices.split(self.mem_pool_host.page_size)
630
- for chunk in chunks:
631
- self.host_mem_release_queue.put(chunk)
632
-
633
- def _3fs_zero_copy_batch_exists(self, batch_hashes):
634
- _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
635
- hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
636
- return hit_page_num
637
-
638
- def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
639
- hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
640
- hash_values, host_indices
641
- )
642
- page_data = self.storage_backend.batch_get(hashes, dsts)
643
- if page_data:
644
- inc = self.page_size * len(hashes) // factor
645
- operation.increment(inc)
646
- else:
647
- logger.warning(
648
- f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
649
- )
574
+ if host_indices.numel() == 0:
575
+ return
576
+ pages = host_indices.split(self.mem_pool_host.page_size)
577
+ for page in pages:
578
+ self.host_mem_release_queue.put(page)
650
579
 
651
- def _mooncake_page_get(self, operation, hash_values, host_indices):
652
- key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
653
- hash_values,
654
- host_indices,
655
- self.storage_config.tp_rank,
656
- )
657
- get_result = self.storage_backend.batch_get(
658
- key_strs,
659
- target_locations=buffer_ptrs,
660
- target_sizes=buffer_sizes,
580
+ def _page_get_zero_copy(
581
+ self, operation, hash_values, host_indices, extra_info=None
582
+ ):
583
+ results = self.storage_backend.batch_get_v1(
584
+ hash_values, host_indices, extra_info
661
585
  )
662
- if get_result != len(hash_values):
663
- logger.warning(
664
- f"Prefetch operation {operation.request_id} failed or partially failed."
665
- )
666
- if get_result != 0:
667
- operation.increment(get_result * self.page_size)
586
+ inc = 0
587
+ for i in range(len(hash_values)):
588
+ if not results[i]:
589
+ logger.warning(
590
+ f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
591
+ )
592
+ break
593
+ inc += self.page_size
594
+ operation.increment(inc)
668
595
 
669
- def _generic_page_get(self, operation, hash_values, host_indices):
596
+ # todo: deprecate
597
+ def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
670
598
  dummy_page_dst = [
671
599
  self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
672
600
  ]
@@ -690,6 +618,7 @@ class HiCacheController:
690
618
 
691
619
  def _page_transfer(self, operation):
692
620
  # Transfer batch by batch
621
+ prefix_keys = operation.prefix_keys
693
622
  for i in range(0, len(operation.hash_value), self.storage_batch_size):
694
623
  batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
695
624
  batch_host_indices = operation.host_indices[
@@ -697,7 +626,8 @@ class HiCacheController:
697
626
  ]
698
627
  prev_completed_tokens = operation.completed_tokens
699
628
  # Get one batch token, and update the completed_tokens if succeed
700
- self.page_get_func(operation, batch_hashes, batch_host_indices)
629
+ extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
630
+ self.page_get_func(operation, batch_hashes, batch_host_indices, extra_info)
701
631
  # Check termination
702
632
  if (
703
633
  operation.completed_tokens
@@ -705,6 +635,10 @@ class HiCacheController:
705
635
  ):
706
636
  operation.mark_terminate()
707
637
  break # Some operations fail or operation terminated by controller
638
+
639
+ if prefix_keys and len(prefix_keys) > 0:
640
+ prefix_keys += batch_hashes
641
+
708
642
  # release pre-allocated memory
709
643
  self.append_host_mem_release(
710
644
  operation.host_indices[operation.completed_tokens :]
@@ -738,6 +672,7 @@ class HiCacheController:
738
672
  def _storage_hit_query(self, operation) -> tuple[list[str], int]:
739
673
  last_hash = operation.last_hash
740
674
  tokens_to_fetch = operation.token_ids
675
+ prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None
741
676
 
742
677
  storage_query_count = 0
743
678
  hash_value = []
@@ -755,11 +690,15 @@ class HiCacheController:
755
690
  batch_tokens[i : i + self.page_size], last_hash
756
691
  )
757
692
  batch_hashes.append(last_hash)
758
- hit_page_num = self.batch_exists_func(batch_hashes)
693
+ extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
694
+ hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info)
759
695
  hash_value.extend(batch_hashes[:hit_page_num])
760
696
  storage_query_count += hit_page_num * self.page_size
761
697
  if hit_page_num < len(batch_hashes):
762
698
  break
699
+ if prefix_keys and len(prefix_keys) > 0:
700
+ prefix_keys += batch_hashes
701
+
763
702
  return hash_value, storage_query_count
764
703
 
765
704
  def prefetch_thread_func(self):
@@ -816,46 +755,34 @@ class HiCacheController:
816
755
  host_indices: torch.Tensor,
817
756
  token_ids: List[int],
818
757
  hash_value: Optional[List[str]] = None,
758
+ prefix_keys: Optional[List[str]] = None,
819
759
  ) -> int:
820
760
  """
821
761
  Write KV caches from host memory to storage backend.
822
762
  """
823
- operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
763
+ operation = StorageOperation(
764
+ host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys
765
+ )
824
766
  self.backup_queue.put(operation)
825
767
  return operation.id
826
768
 
827
- # non-zero copy
828
- def _generic_page_set(self, hash_values, host_indices) -> bool:
769
+ # todo: deprecate
770
+ def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool:
829
771
  data = [
830
- self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
772
+ self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
831
773
  for i in range(len(hash_values))
832
774
  ]
833
775
  return self.storage_backend.batch_set(hash_values, data)
834
776
 
835
- # zero copy
836
- def _mooncake_page_set(self, hash_values, host_indices) -> bool:
837
- key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
838
- hash_values,
839
- host_indices,
840
- self.storage_config.tp_rank,
841
- )
842
- success = self.storage_backend.batch_set(
843
- key_strs,
844
- target_locations=buffer_ptrs,
845
- target_sizes=buffer_sizes,
777
+ def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool:
778
+ return all(
779
+ self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
846
780
  )
847
- return success
848
-
849
- # zero copy
850
- def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
851
- hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
852
- hash_values, host_indices
853
- )
854
- return self.storage_backend.batch_set(hashes, dsts)
855
781
 
856
782
  # Backup batch by batch
857
783
  def _page_backup(self, operation):
858
784
  # Backup batch by batch
785
+ prefix_keys = operation.prefix_keys
859
786
  for i in range(0, len(operation.hash_value), self.storage_batch_size):
860
787
  batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
861
788
  batch_host_indices = operation.host_indices[
@@ -863,12 +790,16 @@ class HiCacheController:
863
790
  ]
864
791
  # Set one batch token, and record if success.
865
792
  # todo: allow partial success
866
- success = self.page_set_func(batch_hashes, batch_host_indices)
793
+ extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
794
+ success = self.page_set_func(batch_hashes, batch_host_indices, extra_info)
867
795
  if not success:
868
796
  logger.warning(
869
797
  f"Write page to storage: {len(batch_hashes)} pages failed."
870
798
  )
871
799
  break
800
+
801
+ if prefix_keys and len(prefix_keys) > 0:
802
+ prefix_keys += batch_hashes
872
803
  operation.completed_tokens += self.page_size * len(batch_hashes)
873
804
 
874
805
  def backup_thread_func(self):