sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py (new file)
@@ -0,0 +1,284 @@
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
+
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError as e:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache`"
+    ) from e
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+
+class LayerTransferCounter:
+    """Minimal adapter that lets the memory pool notify LMCache per-layer.
+
+    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
+    translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
+    within the provided CUDA stream.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        load_stream: torch.cuda.Stream,
+        lmc_connector: LMCacheLayerwiseConnector,
+        printable: bool = False,
+    ):
+        self.num_layers = num_layers
+        self.load_stream = load_stream
+        self.lmc_connector = lmc_connector
+
+    def wait_until(self, layer_id: int):
+        # Ensure ordering of the async loads wrt compute stream(s).
+        self.load_stream.synchronize()
+        with self.load_stream:
+            self.lmc_connector.load_kv_layerwise(layer_id)
+
+
+class LMCRadixCache(RadixCache):
+    """RadixCache + LMCache IO.
+
+    This subclass adds:
+    - LMCache connector setup (device/host buffers, TP rank/size)
+    - Two CUDA streams for async load/store
+    - Layer-wise transfer executor wiring to the KV cache
+    - Overridden `match_prefix` to fetch missing prefix chunks from LMCache
+    - Extended cache-finalization paths to store back into LMCache
+    - Eviction barrier that respects any in-flight host->device stores
+    """
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        page_size: int,
+        disable: bool = False,
+        enable_kv_cache_events: bool = False,
+        model_config: Optional["ModelConfig"] = None,
+        tp_size: int = 1,
+        rank: int = 0,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+        eviction_policy: str = "lru",
+    ):
+        super().__init__(
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
+            page_size=page_size,
+            disable=disable,
+            enable_kv_cache_events=enable_kv_cache_events,
+            eviction_policy=eviction_policy,
+        )
+
+        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
+        self.lmcache_connector = LMCacheLayerwiseConnector(
+            sgl_config=model_config,
+            tp_size=tp_size,
+            rank=rank,
+            # NOTE: The original implementation accessed private buffers via
+            # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when
+            # available; fall back to private fields if needed.
+            k_pool=getattr(
+                kvcache,
+                "k_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
+            ),
+            v_pool=getattr(
+                kvcache,
+                "v_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
+            ),
+            tp_group=tp_group,
+        )
+
+        self.load_stream = torch.cuda.Stream()
+        self.store_stream = torch.cuda.Stream()
+
+        self.layer_done_executor = LayerTransferCounter(
+            num_layers=(
+                model_config.num_hidden_layers if model_config is not None else 0
+            ),
+            load_stream=self.load_stream,
+            lmc_connector=self.lmcache_connector,
+        )
+        kvcache.register_layer_transfer_counter(self.layer_done_executor)
+
+        self._in_flight_nodes: list[TreeNode] = []
+        self._node_lock = threading.Lock()
+
+    def reset(self):  # type: ignore[override]
+        super().reset()
+        if hasattr(self, "_in_flight_nodes"):
+            with self._node_lock:
+                self._in_flight_nodes.clear()
+
+    def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:  # type: ignore[override]
+        """Match cached prefix; if there's a tail miss, prefetch from LMCache.
+
+        Reuses the base matching logic to obtain (value, last_node). If there
+        remains a *page-aligned* uncached suffix and there is room (or after
+        eviction), we allocate token slots and trigger an async LMCache load
+        into those slots, then materialize a new child node for the retrieved
+        chunk.
+        """
+        if self.disable or not key:
+            return super().match_prefix(key, **kwargs)
+
+        if self.page_size != 1:
+            aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:aligned_len]
+
+        base_res = super().match_prefix(key, **kwargs)
+        value: torch.Tensor = base_res.device_indices
+        last_node: TreeNode = base_res.last_device_node
+
+        if value.numel() == len(key):
+            return base_res
+
+        uncached_len = len(key) - value.numel()
+        if uncached_len == 0:
+            return base_res
+
+        chunk_size = self.lmcache_connector.chunk_size()
+        prefix_pad = value.numel() % chunk_size
+
+        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
+            self.evict(uncached_len)
+
+        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
+        if token_slots is None:
+            return base_res
+
+        slot_mapping = torch.cat(
+            [
+                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
+                token_slots.detach().clone().to(torch.int64).to(self.device),
+            ]
+        )
+
+        with torch.cuda.stream(self.load_stream):
+            num_retrieved = self.lmcache_connector.start_load_kv(
+                LoadMetadata(
+                    token_ids=key.token_ids,  # full page-aligned key
+                    slot_mapping=slot_mapping,
+                    offset=value.numel() - prefix_pad,  # LMCache offset convention
+                )
+            )
+        logger.debug("num_retrieved_tokens: %s", num_retrieved)
+
+        if num_retrieved > 0:
+            self.token_to_kv_pool_allocator.free(
+                token_slots[(num_retrieved - prefix_pad) :]
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(token_slots)
+
+        if num_retrieved > 0:
+            fetched = num_retrieved - prefix_pad
+            new_node = TreeNode()
+            start = value.numel()
+            end = start + fetched
+            new_node.key = key[start:end]
+            new_node.value = token_slots[:fetched]
+            new_node.parent = last_node
+            last_node.children[self.get_child_key_fn(new_node.key)] = new_node
+            last_node = new_node
+
+            value = torch.cat([value, token_slots[:fetched]])
+            self.evictable_size_ += fetched
+
+            self._record_store_event(new_node.parent)
+            self._record_store_event(new_node)
+
+            return MatchResult(
+                device_indices=value,
+                last_device_node=last_node,
+                last_host_node=last_node,
+            )
+
+        return base_res
+
+    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
+        """On request completion, insert device KV into radix and store to LMCache."""
+
+        super().cache_finished_req(req)
+
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        _, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
+        assert new_last_node is not None
+
+        self.inc_lock_ref(new_last_node)
+        store_md = StoreMetadata(
+            last_node=new_last_node,
+            token_ids=token_ids,
+            kv_indices=kv_indices,
+            offset=0,
+        )
+        with torch.cuda.stream(self.store_stream):
+            self.lmcache_connector.store_kv(store_md)
+        with self._node_lock:
+            self._in_flight_nodes.append(new_last_node)
+
+    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
+        """Before base eviction, wait for any outstanding stores and release locks."""
+        if self.disable:
+            return
+
+        self.store_stream.synchronize()
+        with self._node_lock:
+            for node in self._in_flight_nodes:
+                self.dec_lock_ref(node)
+            self._in_flight_nodes.clear()
+
+        super().evict(num_tokens)
+
+    def pretty_print(self):  # type: ignore[override]
+        super().pretty_print()
+        try:
+            logger.debug(
+                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
+            )
+        except Exception:  # pragma: no cover
+            pass
+
+
+if __name__ == "__main__":
+    cache = LMCRadixCache(
+        req_to_token_pool=None,
+        token_to_kv_pool_allocator=None,
+        page_size=1,
+        disable=False,
+        enable_kv_cache_events=False,
+        model_config=None,
+        tp_size=1,
+        rank=0,
+        tp_group=None,
+    )
+    cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
+    cache.insert(
+        RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
+    )
+    cache.pretty_print()
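
The layer-wise handoff above works as follows: the device KV cache registers the LayerTransferCounter via register_layer_transfer_counter, and after the forward pass finishes each layer it calls wait_until(layer_id), which drains the load stream and asks the LMCache connector to copy that layer's retrieved KV into the freshly allocated token slots. Below is a minimal, self-contained sketch of that protocol; _FakeConnector and the driving loop are stand-ins invented for illustration (the real counterpart is LMCacheLayerwiseConnector and the KV pool's per-layer callback), and it assumes a CUDA device is present.

    # Illustrative sketch only, not part of the package.
    import torch


    class _FakeConnector:
        def load_kv_layerwise(self, layer_id: int) -> None:
            # The real connector copies layer `layer_id` of the retrieved KV chunk
            # into the device KV pool slots set up by start_load_kv().
            print(f"load KV for layer {layer_id}")


    class LayerTransferCounterSketch:
        """Mirrors the adapter in the diff: one wait_until() call per finished layer."""

        def __init__(self, num_layers: int, load_stream: "torch.cuda.Stream", connector):
            self.num_layers = num_layers
            self.load_stream = load_stream
            self.connector = connector

        def wait_until(self, layer_id: int) -> None:
            # Block until previously issued loads are done, then queue this layer's load
            # on the dedicated load stream.
            self.load_stream.synchronize()
            with torch.cuda.stream(self.load_stream):
                self.connector.load_kv_layerwise(layer_id)


    if __name__ == "__main__" and torch.cuda.is_available():
        counter = LayerTransferCounterSketch(4, torch.cuda.Stream(), _FakeConnector())
        for layer_id in range(counter.num_layers):
            # The KV cache would invoke this right before attention reads layer `layer_id`.
            counter.wait_until(layer_id)
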
sglang/srt/mem_cache/storage/lmcache/unit_test.py (new file)
@@ -0,0 +1,121 @@
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
+    )
+
+import os
+
+import torch
+
+from sglang.srt.configs.model_config import ModelConfig
+
+os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
+os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
+
+
+def test_load_store_metadata():
+    model_config = ModelConfig(
+        model_path="Qwen/Qwen3-4B",
+    )
+
+    # Generate Dummy KV Cache
+    head_num = model_config.num_key_value_heads
+    head_dim = model_config.head_dim
+    layer_num = model_config.num_hidden_layers
+    buffer_size = 256
+    input_id_len = 16
+
+    k_buffer = [
+        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    v_buffer = [
+        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)
+
+    fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
+    fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
+    offset = 0
+
+    store_metadata = StoreMetadata(
+        last_node=None,
+        token_ids=fake_token_ids,
+        kv_indices=fake_kv_indices,
+        offset=offset,
+    )
+
+    load_metadata = LoadMetadata(
+        token_ids=fake_token_ids,
+        slot_mapping=fake_kv_indices,
+        offset=offset,
+    )
+
+    current_stream = torch.cuda.current_stream()
+
+    retrieve_token_num = connector.start_load_kv(load_metadata)
+    assert retrieve_token_num == 0
+
+    connector.store_kv(store_metadata)
+    current_stream.synchronize()
+
+    # check retrieve
+    gt_key_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    gt_value_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    for i in range(layer_num):
+        gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
+        gt_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+    # clear the k_buffer and v_buffer
+    for _ in range(layer_num):
+        k_buffer[i].zero_()
+        v_buffer[i].zero_()
+
+    retrieve_token_num = connector.start_load_kv(load_metadata)
+    assert retrieve_token_num == input_id_len
+
+    for i in range(layer_num):
+        current_stream.synchronize()
+        connector.load_kv_layerwise(i)
+
+    current_stream.synchronize()
+    test_key_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    test_value_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    for i in range(layer_num):
+        test_key_buffer[i] = k_buffer[i][fake_kv_indices]
+        test_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+    for i in range(layer_num):
+        assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
+        assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])
+
+    print("================================================")
+    print("TEST_LOAD_STORE_METADATA PASSED!")
+    print("================================================")
+    connector.close()
+
+
+if __name__ == "__main__":
+    test_load_store_metadata()
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -7,11 +7,16 @@ from typing import Any, List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
-
+DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH"
 logger = logging.getLogger(__name__)
 
 
@@ -28,13 +33,13 @@ class MooncakeStoreConfig:
     @staticmethod
     def from_file() -> "MooncakeStoreConfig":
         """Load the config from a JSON file."""
-        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
-        if file_path is None:
-            raise ValueError(
-                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
-            )
-        with open(file_path) as fin:
-            config = json.load(fin)
+        file_path = os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV)
+        try:
+            with open(file_path) as fin:
+                config = json.load(fin)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load config from {file_path}: {str(e)}")
+
         return MooncakeStoreConfig(
             local_hostname=config.get("local_hostname"),
             metadata_server=config.get("metadata_server"),
@@ -72,6 +77,26 @@ class MooncakeStoreConfig:
             master_server_address=os.getenv("MOONCAKE_MASTER"),
         )
 
+    @staticmethod
+    def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
+        """Load config from extra_config dictionary."""
+        if "master_server_address" not in extra_config:
+            raise ValueError("master_server_address is required in extra_config")
+
+        return MooncakeStoreConfig(
+            local_hostname=extra_config.get("local_hostname", "localhost"),
+            metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
+            global_segment_size=extra_config.get(
+                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+            ),
+            local_buffer_size=extra_config.get(
+                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+            ),
+            protocol=extra_config.get("protocol", "tcp"),
+            device_name=extra_config.get("device_name", "auto"),
+            master_server_address=extra_config["master_server_address"],
+        )
+
     def __post_init__(self):
         if self.device_name == "auto":
             os.environ["MC_MS_AUTO_DISC"] = "1"
@@ -81,6 +106,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
+
     def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
@@ -93,14 +119,43 @@ class MooncakeStore(HiCacheStorage):
 
         try:
             self.store = MooncakeDistributedStore()
-            self.config = MooncakeStoreConfig.load_from_env()
-            logger.info("Mooncake Configuration loaded from env successfully.")
+
+            extra_config = (
+                getattr(storage_config, "extra_config", None)
+                if storage_config
+                else None
+            )
+            # Load configuration with master_server_address prioritized from extra_config if available
+            if (
+                extra_config is not None
+                and extra_config.get("master_server_address") is not None
+            ):
+                # Load from extra_config
+                self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
+                logger.info(
+                    "Mooncake Configuration loaded from extra_config successfully."
+                )
+            elif os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV):
+                # Load from config file
+                self.config = MooncakeStoreConfig.from_file()
+                logger.info("Mooncake Configuration loaded from file successfully.")
+            else:
+                # Load from environment variables
+                self.config = MooncakeStoreConfig.load_from_env()
+                logger.info("Mooncake Configuration loaded from env successfully.")
+
+            tp_scale_factor = 1 if storage_config is None else storage_config.tp_size
+
+            per_tp_global_segment_size = (
+                self.config.global_segment_size // tp_scale_factor
+            )
+            per_tp_local_buffer_size = self.config.local_buffer_size // tp_scale_factor
 
             ret_code = self.store.setup(
                 self.config.local_hostname,
                 self.config.metadata_server,
-                self.config.global_segment_size,
-                self.config.local_buffer_size,
+                per_tp_global_segment_size,
+                per_tp_local_buffer_size,
                 self.config.protocol,
                 self.config.device_name,
                 self.config.master_server_address,
@@ -133,7 +188,13 @@ class MooncakeStore(HiCacheStorage):
         assert self.store.is_exist(warmup_key) == 1
         assert self.store.get(warmup_key) == warmup_value
 
-    def register_buffer(self, buffer: torch.Tensor) -> None:
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        assert self.mem_pool_host.layout in [
+            "page_first",
+            "page_first_direct",
+        ], "mooncake store storage backend only support page first or page first direct layout"
+        buffer = self.mem_pool_host.kv_buffer
         try:
             buffer_ptr = buffer.data_ptr()
             buffer_size = buffer.numel() * buffer.element_size()
@@ -144,6 +205,97 @@ class MooncakeStore(HiCacheStorage):
             logger.error("Failed to register buffer to Mooncake Store: %s", err)
             raise TypeError("Mooncake Store Register Buffer Error.") from err
 
+    def _get_mha_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_{self.local_rank}_k")
+            key_list.append(f"{key_}_{self.local_rank}_v")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _get_mla_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_k")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _batch_preprocess(self, keys, host_indices):
+        assert len(keys) > 0
+        assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
+        if self.is_mla_backend:
+            return self._get_mla_buffer_meta(keys, host_indices)
+        else:
+            return self._get_mha_buffer_meta(keys, host_indices)
+
+    def _batch_postprocess(self, results: List[int], is_set_operate=False):
+        """
+        refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
+        for batch_get_into, results is Vector of integers,
+        where each element is the number of bytes read on success, or a negative value on error
+        for batch_put_from, results is Vector of integers,
+        where each element is 0 on success, or a negative value on error
+        """
+        if self.is_mla_backend:
+            return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
+        else:
+            kv_pairs = zip(results[::2], results[1::2])
+            return [
+                (
+                    (k_res == 0 and v_res == 0)
+                    if is_set_operate
+                    else (k_res > 0 and v_res > 0)
+                )
+                for k_res, v_res in kv_pairs
+            ]
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        get_results = self._get_batch_zero_copy_impl(
+            key_strs, buffer_ptrs, buffer_sizes
+        )
+        return self._batch_postprocess(get_results, is_set_operate=False)
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        exist_result = self._batch_exist(key_strs)
+
+        set_keys = []
+        set_buffer_ptrs = []
+        set_buffer_sizes = []
+        set_indices = []
+        set_results = [-1] * len(key_strs)
+        for i in range(len(key_strs)):
+            if exist_result[i] != 1:
+                set_keys.append(key_strs[i])
+                set_buffer_ptrs.append(buffer_ptrs[i])
+                set_buffer_sizes.append(buffer_sizes[i])
+                set_indices.append(i)
+            else:
+                set_results[i] = 0
+
+        # Only set non-existing keys to storage
+        if len(set_keys) > 0:
+            put_results = self._put_batch_zero_copy_impl(
+                set_keys, set_buffer_ptrs, set_buffer_sizes
+            )
+            for i in range(len(set_indices)):
+                set_results[set_indices[i]] = put_results[i]
+
+        return self._batch_postprocess(set_results, is_set_operate=True)
+
     def set(
         self,
         key,
@@ -264,9 +416,6 @@ class MooncakeStore(HiCacheStorage):
                 return i // key_multiplier
         return len(query_keys) // key_multiplier
 
-    def delete(self, key) -> None:
-        raise (NotImplementedError)
-
     def close(self):
         # MooncakeDistributedStore will automatically call the destructor, so
         # it is unnecessary to close it manually.
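
The configuration resolution added above is ordered: an extra_config dict that supplies master_server_address wins, then a JSON file named by SGLANG_HICACHE_MOONCAKE_CONFIG_PATH, and finally the plain environment variables handled by load_from_env; the resulting segment and buffer sizes are then divided by the TP size before store.setup(). Below is a small sketch of the extra_config path, assuming sglang 0.5.3 is importable; the addresses and the RDMA device name are placeholders, and omitted keys fall back to the defaults shown in load_from_extra_config.

    # Sketch: resolving a MooncakeStoreConfig from extra_config (placeholder values).
    from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
        MooncakeStoreConfig,
    )

    extra_config = {
        # Required: without it, MooncakeStore falls back to the config file named by
        # SGLANG_HICACHE_MOONCAKE_CONFIG_PATH, and then to environment variables.
        "master_server_address": "10.0.0.1:50051",
        "local_hostname": "10.0.0.2",
        "protocol": "rdma",       # defaults to "tcp" when omitted
        "device_name": "mlx5_0",  # defaults to "auto" when omitted
        # global_segment_size / local_buffer_size default to 4 GiB / 16 MB and are
        # later divided by the TP size before being passed to store.setup().
    }

    config = MooncakeStoreConfig.load_from_extra_config(extra_config)
    print(config.metadata_server)  # "P2PHANDSHAKE" unless overridden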