sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,115 @@
1
+ import argparse
2
+ import os
3
+
4
+ import eic
5
+ import torch
6
+ import yaml
7
+
8
+
9
def pase_args():
    """Parse the command-line options of the EIC storage unit test.

    Only ``--config`` / ``-c`` is recognised. Any other command-line
    argument is ignored, because ``parse_known_args`` is used.

    Returns:
        argparse.Namespace: namespace whose ``config`` attribute is the path
        of the EIC yaml config file.
    """
    cli = argparse.ArgumentParser(description="EIC Storage Unit Test")
    config_option = {
        "type": str,
        "default": "/sgl-workspace/config/remote-eic.yaml",
        "help": "EIC yaml config",
    }
    cli.add_argument("--config", "-c", **config_option)
    known, _unknown = cli.parse_known_args()
    return known
20
+
21
+
22
def init_eic_client():
    """Build and initialise an EIC client from the yaml config named on the CLI.

    The config path is taken from ``pase_args()``. The yaml mapping is read
    for ``remote_url`` (required), ``eic_instance_id``, ``eic_log_dir``,
    ``eic_log_level`` (default 2), ``eic_trans_type`` (default 3) and
    ``eic_flag_file``.

    Returns:
        The ``eic.Client`` instance after a successful ``init`` call.

    Raises:
        FileNotFoundError: the config file does not exist.
        AssertionError: the config has no ``remote_url`` entry.
        RuntimeError: ``eic.Client.init`` returned a non-zero code.
    """
    args = pase_args()
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path, "r") as fin:
        config = yaml.safe_load(fin)

    remote_url = config.get("remote_url", None)
    if remote_url is None:
        # Fail here with a clear message; without the raise the slice below
        # would blow up with an unrelated TypeError on None.
        raise AssertionError("remote_url is None")
    # Drops the first len("eic://") characters; the prefix itself is not checked.
    endpoint = remote_url[len("eic://") :]
    eic_instance_id = config.get("eic_instance_id", None)
    eic_log_dir = config.get("eic_log_dir", None)
    eic_log_level = config.get("eic_log_level", 2)
    eic_trans_type = config.get("eic_trans_type", 3)
    eic_flag_file = config.get("eic_flag_file", None)

    # NOTE(review): eic_log_dir defaults to None, which os.path.exists rejects
    # with TypeError -- confirm the config is always expected to set it.
    if not os.path.exists(eic_log_dir):
        os.makedirs(eic_log_dir, exist_ok=True)
    eic_client = eic.Client()
    init_option = eic.InitOption()
    init_option.log_dir = eic_log_dir
    init_option.log_level = eic.LogLevel(eic_log_level)
    init_option.transport_type = eic.TransportType(eic_trans_type)
    init_option.flag_file = eic_flag_file
    ret = eic_client.init(eic_instance_id, endpoint, init_option)
    if ret != 0:
        raise RuntimeError(f"EIC Client init failed with error code: {ret}")
    return eic_client
52
+
53
+
54
def test_set(eic_client):
    """Store 16 bfloat16 tensors of ones in EIC with a single ``mset`` call.

    Keys are ``test_key_0`` .. ``test_key_15`` and the set option carries
    ``ttl_second = 3``.

    Args:
        eic_client: client object exposing ``mset``.

    Raises:
        AssertionError: ``mset`` did not report ``eic.StatusCode.SUCCESS``.
    """
    entry_count = 16
    # The list keeps every tensor referenced until the call has returned,
    # since only data_ptr() addresses are handed to the buffers.
    payloads = [
        torch.ones([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu")
        for _ in range(entry_count)
    ]
    data_keys = eic.StringVector()
    data_vals = eic.IOBuffers()
    for idx, payload in enumerate(payloads):
        data_keys.append(f"test_key_{idx}")
        byte_count = payload.numel() * payload.element_size()
        data_vals.append(payload.data_ptr(), byte_count, False)

    set_opt = eic.SetOption()
    set_opt.ttl_second = 3
    status_code, _outcome = eic_client.mset(data_keys, data_vals, set_opt)
    succeeded = status_code == eic.StatusCode.SUCCESS
    assert succeeded, f"Set failed with status code: {status_code}"
73
+
74
+
75
def test_get(eic_client):
    """Fetch the 16 ``test_key_*`` entries with a single ``mget`` call.

    Zero-filled bfloat16 tensors are supplied as destination buffers. Only
    the returned status code is checked; the tensor contents are not
    inspected.

    Args:
        eic_client: client object exposing ``mget``.

    Raises:
        AssertionError: ``mget`` did not report ``eic.StatusCode.SUCCESS``.
    """
    entry_count = 16
    # Destination tensors stay referenced here for the duration of the call.
    landing = [
        torch.zeros([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu")
        for _ in range(entry_count)
    ]
    data_keys = eic.StringVector()
    data_vals = eic.IOBuffers()
    for idx, buf in enumerate(landing):
        data_keys.append(f"test_key_{idx}")
        byte_count = buf.numel() * buf.element_size()
        data_vals.append(buf.data_ptr(), byte_count, False)

    get_opt = eic.GetOption()
    status_code, data_vals, _outcome = eic_client.mget(data_keys, get_opt, data_vals)
    succeeded = status_code == eic.StatusCode.SUCCESS
    assert succeeded, f"Get failed with status code: {status_code}"
93
+
94
+
95
def test_exists(eic_client):
    """Issue one ``mexist`` call for ``test_key_0`` .. ``test_key_15``.

    Only the returned status code is checked; the per-key outcome is ignored.

    Args:
        eic_client: client object exposing ``mexist``.

    Raises:
        AssertionError: ``mexist`` did not report ``eic.StatusCode.SUCCESS``.
    """
    data_keys = eic.StringVector()
    for idx in range(16):
        data_keys.append(f"test_key_{idx}")

    exists_opt = eic.ExistOption()
    status_code, _outcome = eic_client.mexist(data_keys, exists_opt)
    succeeded = status_code == eic.StatusCode.SUCCESS
    assert succeeded, f"Exists failed with status code: {status_code}"
105
+
106
+
107
def main():
    """Initialise an EIC client and run the set, exists and get checks in order."""
    client = init_eic_client()
    # Order matters: the keys are written first, then probed, then read back.
    for check in (test_set, test_exists, test_get):
        check(client)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
@@ -12,7 +12,12 @@ from typing import Any, List, Optional, Tuple
12
12
 
13
13
  import torch
14
14
 
15
- from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
15
+ from sglang.srt.mem_cache.hicache_storage import (
16
+ HiCacheStorage,
17
+ HiCacheStorageConfig,
18
+ HiCacheStorageExtraInfo,
19
+ )
20
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
16
21
  from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
17
22
  from sglang.srt.metrics.collector import StorageMetrics
18
23
 
@@ -178,11 +183,14 @@ class HiCacheHF3FS(HiCacheStorage):
178
183
  self.skip_backup = True
179
184
  self.rank = 0
180
185
 
186
+ self.is_zero_copy = False
187
+
181
188
  logger.info(
182
189
  f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
183
190
  f"file_path={self.file_path}, "
184
191
  f"file_size={self.file_size / (2 ** 30):.2f} GB, "
185
- f"num_pages={self.num_pages}"
192
+ f"num_pages={self.num_pages}, "
193
+ f"is_mla_model={self.is_mla_model}"
186
194
  )
187
195
 
188
196
  self.ac = AtomicCounter(self.numjobs)
@@ -323,25 +331,12 @@ class HiCacheHF3FS(HiCacheStorage):
323
331
  use_mock_client=use_mock_client,
324
332
  )
325
333
 
326
- def get(
327
- self,
328
- key: str,
329
- target_location: Optional[Any] = None,
330
- target_sizes: Optional[Any] = None,
331
- ) -> torch.Tensor | None:
332
- return self.batch_get(
333
- [key],
334
- [target_location] if target_location is not None else None,
335
- [target_sizes] if target_sizes is not None else None,
336
- )[0]
337
-
338
334
  @synchronized()
339
- def batch_get(
335
+ def _batch_get(
340
336
  self,
341
337
  keys: List[str],
342
- target_locations: Optional[Any] = None,
343
- target_sizes: Optional[Any] = None,
344
- ) -> List[torch.Tensor | None]:
338
+ values: List[torch.Tensor],
339
+ ) -> List[bool]:
345
340
  page_indices = self.metadata_client.get_page_indices(self.rank, keys)
346
341
 
347
342
  batch_indices, file_offsets = [], []
@@ -350,15 +345,9 @@ class HiCacheHF3FS(HiCacheStorage):
350
345
  batch_indices.append(i)
351
346
  file_offsets.append(page_index * self.bytes_per_page)
352
347
 
353
- if target_locations is not None:
354
- for target_location in target_locations:
355
- assert target_location.is_contiguous()
356
- file_results = target_locations
357
- else:
358
- file_results = [
359
- torch.empty(self.numel, dtype=self.dtype)
360
- for _ in range(len(batch_indices))
361
- ]
348
+ for target_location in values:
349
+ assert target_location.is_contiguous()
350
+ file_results = values
362
351
 
363
352
  start_time = time.perf_counter()
364
353
 
@@ -379,12 +368,10 @@ class HiCacheHF3FS(HiCacheStorage):
379
368
  ionum / (end_time - start_time) * self.gb_per_page
380
369
  )
381
370
 
382
- results = [None] * len(keys)
383
- for batch_index, file_result, read_result in zip(
384
- batch_indices, file_results, read_results
385
- ):
371
+ results = [False] * len(keys)
372
+ for batch_index, read_result in zip(batch_indices, read_results):
386
373
  if read_result == self.bytes_per_page:
387
- results[batch_index] = file_result
374
+ results[batch_index] = True
388
375
  else:
389
376
  logger.error(
390
377
  f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
@@ -392,28 +379,12 @@ class HiCacheHF3FS(HiCacheStorage):
392
379
 
393
380
  return results
394
381
 
395
- def set(
396
- self,
397
- key: str,
398
- value: Optional[Any] = None,
399
- target_location: Optional[Any] = None,
400
- target_sizes: Optional[Any] = None,
401
- ) -> bool:
402
- return self.batch_set(
403
- [key],
404
- [value] if value is not None else None,
405
- [target_location] if target_location is not None else None,
406
- [target_sizes] if target_sizes is not None else None,
407
- )
408
-
409
382
  @synchronized()
410
- def batch_set(
383
+ def _batch_set(
411
384
  self,
412
385
  keys: List[str],
413
386
  values: Optional[Any] = None,
414
- target_locations: Optional[Any] = None,
415
- target_sizes: Optional[Any] = None,
416
- ) -> bool:
387
+ ) -> List[bool]:
417
388
  # In MLA backend, only one rank needs to backup the KV cache
418
389
  if self.skip_backup:
419
390
  return True
@@ -474,7 +445,7 @@ class HiCacheHF3FS(HiCacheStorage):
474
445
  self.rank, written_keys_to_confirm, pages_to_release
475
446
  )
476
447
 
477
- return all(results)
448
+ return results
478
449
 
479
450
  def delete(self, key: str) -> None:
480
451
  self.metadata_client.delete_keys(self.rank, [key])
@@ -484,21 +455,25 @@ class HiCacheHF3FS(HiCacheStorage):
484
455
  return result[0] if result else False
485
456
 
486
457
  def batch_exists(self, keys: List[str]) -> int:
458
+ factor = 1
459
+ if self.is_zero_copy and not self.is_mla_model:
460
+ keys = self._get_mha_zero_copy_keys(keys)
461
+ factor = 2
462
+
487
463
  results = self.metadata_client.exists(self.rank, keys)
488
- for i in range(len(keys)):
489
- if not results[i]:
490
- return i
491
464
 
492
- return len(keys)
465
+ i = 0
466
+ while i < len(keys) and results[i]:
467
+ i += 1
493
468
 
494
- def clear(self) -> bool:
469
+ return i // factor
470
+
471
+ def clear(self) -> None:
495
472
  try:
496
473
  self.metadata_client.clear(self.rank)
497
474
  logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
498
- return True
499
475
  except Exception as e:
500
476
  logger.error(f"Failed to clear HiCacheHF3FS: {e}")
501
- return False
502
477
 
503
478
  def close(self) -> None:
504
479
  try:
@@ -521,3 +496,143 @@ class HiCacheHF3FS(HiCacheStorage):
521
496
  self.prefetch_bandwidth.clear()
522
497
  self.backup_bandwidth.clear()
523
498
  return storage_metrics
499
+
500
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
    """Register the host KV cache and derive the zero-copy flag from its layout.

    Registration itself is delegated to the parent class. Afterwards
    ``is_zero_copy`` is set to True exactly when
    ``self.mem_pool_host.layout`` equals ``"page_first"``, and the resulting
    flag is logged.
    """
    super().register_mem_pool_host(mem_pool_host)
    layout = self.mem_pool_host.layout
    self.is_zero_copy = layout == "page_first"
    logger.info(f"{self.is_zero_copy=}")
504
+
505
+ def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
506
+ _keys = []
507
+ for k in keys:
508
+ _keys.append(f"{k}-k")
509
+ _keys.append(f"{k}-v")
510
+ return _keys
511
+
512
+ def _get_mha_zero_copy_values(
513
+ self, values: List[torch.Tensor]
514
+ ) -> List[torch.Tensor]:
515
+ _values = []
516
+ for value in values:
517
+ _values.append(value[0])
518
+ _values.append(value[1])
519
+ return _values
520
+
521
+ def _batch_get_preprocess(self, keys, host_indices):
522
+ page_num = len(host_indices) // self.mem_pool_host.page_size
523
+ # host_indices to kv_buffer
524
+ flat = not self.is_zero_copy
525
+ values = (
526
+ [
527
+ self.mem_pool_host.get_data_page(
528
+ host_indices[i * self.mem_pool_host.page_size], flat=flat
529
+ )
530
+ for i in range(page_num)
531
+ ]
532
+ if self.is_zero_copy
533
+ else [
534
+ self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
535
+ ]
536
+ )
537
+
538
+ if self.is_zero_copy and not self.is_mla_model:
539
+ keys = self._get_mha_zero_copy_keys(keys)
540
+ values = self._get_mha_zero_copy_values(values)
541
+
542
+ return keys, values
543
+
544
+ def _batch_get_postprocess(self, host_indices, values, results):
545
+ page_num = len(host_indices) // self.mem_pool_host.page_size
546
+
547
+ if self.is_zero_copy:
548
+ if not self.is_mla_model:
549
+ results = [
550
+ (results[2 * i] and results[2 * i + 1]) for i in range(page_num)
551
+ ]
552
+ results = results[:page_num]
553
+ return results
554
+
555
+ for i in range(page_num):
556
+ if not results[i]:
557
+ break
558
+ self.mem_pool_host.set_from_flat_data_page(
559
+ host_indices[i * self.mem_pool_host.page_size], values[i]
560
+ )
561
+
562
+ return results
563
+
564
def batch_get_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """Read the pages for ``keys`` into the host pool slots at ``host_indices``.

    Runs the three read stages in sequence: ``_batch_get_preprocess``,
    ``_batch_get`` and ``_batch_get_postprocess``. ``extra_info`` is accepted
    but not used here.

    Returns:
        The list produced by ``_batch_get_postprocess``.
    """
    storage_keys, buffers = self._batch_get_preprocess(keys, host_indices)
    read_results = self._batch_get(storage_keys, buffers)
    return self._batch_get_postprocess(host_indices, buffers, read_results)
573
+
574
+ def _batch_set_preprocess(self, keys, host_indices):
575
+ page_num = len(host_indices) // self.mem_pool_host.page_size
576
+ # host_indices to kv_buffer
577
+ flat = not self.is_zero_copy
578
+ values = [
579
+ self.mem_pool_host.get_data_page(
580
+ host_indices[i * self.mem_pool_host.page_size], flat=flat
581
+ )
582
+ for i in range(page_num)
583
+ ]
584
+
585
+ if self.is_zero_copy and not self.is_mla_model:
586
+ keys = self._get_mha_zero_copy_keys(keys)
587
+ values = self._get_mha_zero_copy_values(values)
588
+
589
+ return keys, values
590
+
591
def batch_set_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """Write the host pool pages at ``host_indices`` to storage under ``keys``.

    Buffers and (possibly expanded) keys come from ``_batch_set_preprocess``
    and are handed to ``_batch_set``, whose result is returned as is.
    ``extra_info`` is accepted but not used here.

    NOTE(review): when keys are expanded into ``-k`` / ``-v`` halves, the
    result presumably has one entry per expanded key rather than per input
    key -- confirm what callers expect.
    """
    keys, values = self._batch_set_preprocess(keys, host_indices)
    return self._batch_set(keys, values)
601
+
602
+ # Deprecated
603
+ def get(
604
+ self,
605
+ key: str,
606
+ target_location: Optional[Any] = None,
607
+ target_sizes: Optional[Any] = None,
608
+ ) -> torch.Tensor | None:
609
+ pass
610
+
611
+ # Deprecated
612
+ def batch_get(
613
+ self,
614
+ keys: List[str],
615
+ target_locations: Optional[Any] = None,
616
+ target_sizes: Optional[Any] = None,
617
+ ) -> List[torch.Tensor | None] | int:
618
+ pass
619
+
620
+ # Deprecated
621
def set(
    self,
    key: str,
    value: Optional[Any] = None,
    target_location: Optional[Any] = None,
    target_sizes: Optional[Any] = None,
) -> bool:
    """Deprecated no-op: stores nothing and returns None."""
    return None
629
+
630
+ # Deprecated
631
def batch_set(
    self,
    keys: List[str],
    values: Optional[Any] = None,
    target_locations: Optional[Any] = None,
    target_sizes: Optional[Any] = None,
) -> bool:
    """Deprecated no-op: stores nothing and returns None."""
    return None
@@ -9,7 +9,7 @@ import torch
9
9
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
10
10
  from sglang.srt.mem_cache.base_prefix_cache import MatchResult
11
11
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
12
- from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
12
+ from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
13
13
 
14
14
  try:
15
15
  from lmcache.integration.sglang.sglang_adapter import (
@@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache):
78
78
  tp_size: int = 1,
79
79
  rank: int = 0,
80
80
  tp_group: Optional[torch.distributed.ProcessGroup] = None,
81
+ eviction_policy: str = "lru",
81
82
  ):
82
83
  super().__init__(
83
84
  req_to_token_pool=req_to_token_pool,
@@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache):
85
86
  page_size=page_size,
86
87
  disable=disable,
87
88
  enable_kv_cache_events=enable_kv_cache_events,
89
+ eviction_policy=eviction_policy,
88
90
  )
89
91
 
90
92
  kvcache = self.token_to_kv_pool_allocator.get_kvcache()
@@ -129,7 +131,7 @@ class LMCRadixCache(RadixCache):
129
131
  with self._node_lock:
130
132
  self._in_flight_nodes.clear()
131
133
 
132
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override]
134
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override]
133
135
  """Match cached prefix; if there's a tail miss, prefetch from LMCache.
134
136
 
135
137
  Reuses the base matching logic to obtain (value, last_node). If there
@@ -176,7 +178,7 @@ class LMCRadixCache(RadixCache):
176
178
  with torch.cuda.stream(self.load_stream):
177
179
  num_retrieved = self.lmcache_connector.start_load_kv(
178
180
  LoadMetadata(
179
- token_ids=key, # full page-aligned key
181
+ token_ids=key.token_ids, # full page-aligned key
180
182
  slot_mapping=slot_mapping,
181
183
  offset=value.numel() - prefix_pad, # LMCache offset convention
182
184
  )
@@ -225,7 +227,7 @@ class LMCRadixCache(RadixCache):
225
227
  req.req_pool_idx, : len(token_ids)
226
228
  ]
227
229
 
228
- _, new_last_node, _, _ = self.match_prefix(token_ids)
230
+ _, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
229
231
  assert new_last_node is not None
230
232
 
231
233
  self.inc_lock_ref(new_last_node)
@@ -275,6 +277,8 @@ if __name__ == "__main__":
275
277
  rank=0,
276
278
  tp_group=None,
277
279
  )
278
- cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
279
- cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
280
+ cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
281
+ cache.insert(
282
+ RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
283
+ )
280
284
  cache.pretty_print()
@@ -7,11 +7,16 @@ from typing import Any, List, Optional
7
7
 
8
8
  import torch
9
9
 
10
- from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
10
+ from sglang.srt.mem_cache.hicache_storage import (
11
+ HiCacheStorage,
12
+ HiCacheStorageConfig,
13
+ HiCacheStorageExtraInfo,
14
+ )
15
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
11
16
 
12
17
  DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
13
18
  DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
14
-
19
+ DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH"
15
20
  logger = logging.getLogger(__name__)
16
21
 
17
22
 
@@ -28,13 +33,13 @@ class MooncakeStoreConfig:
28
33
  @staticmethod
29
34
  def from_file() -> "MooncakeStoreConfig":
30
35
  """Load the config from a JSON file."""
31
- file_path = os.getenv("MOONCAKE_CONFIG_PATH")
32
- if file_path is None:
33
- raise ValueError(
34
- "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
35
- )
36
- with open(file_path) as fin:
37
- config = json.load(fin)
36
+ file_path = os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV)
37
+ try:
38
+ with open(file_path) as fin:
39
+ config = json.load(fin)
40
+ except Exception as e:
41
+ raise RuntimeError(f"Failed to load config from {file_path}: {str(e)}")
42
+
38
43
  return MooncakeStoreConfig(
39
44
  local_hostname=config.get("local_hostname"),
40
45
  metadata_server=config.get("metadata_server"),
@@ -101,6 +106,7 @@ class MooncakeStoreConfig:
101
106
 
102
107
 
103
108
  class MooncakeStore(HiCacheStorage):
109
+
104
110
  def __init__(self, storage_config: HiCacheStorageConfig = None):
105
111
  try:
106
112
  from mooncake.store import MooncakeDistributedStore
@@ -129,6 +135,10 @@ class MooncakeStore(HiCacheStorage):
129
135
  logger.info(
130
136
  "Mooncake Configuration loaded from extra_config successfully."
131
137
  )
138
+ elif os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV):
139
+ # Load from config file
140
+ self.config = MooncakeStoreConfig.from_file()
141
+ logger.info("Mooncake Configuration loaded from file successfully.")
132
142
  else:
133
143
  # Load from environment variables
134
144
  self.config = MooncakeStoreConfig.load_from_env()
@@ -178,7 +188,13 @@ class MooncakeStore(HiCacheStorage):
178
188
  assert self.store.is_exist(warmup_key) == 1
179
189
  assert self.store.get(warmup_key) == warmup_value
180
190
 
181
- def register_buffer(self, buffer: torch.Tensor) -> None:
191
+ def register_mem_pool_host(self, mem_pool_host: HostKVCache):
192
+ super().register_mem_pool_host(mem_pool_host)
193
+ assert self.mem_pool_host.layout in [
194
+ "page_first",
195
+ "page_first_direct",
196
+ ], "mooncake store storage backend only support page first or page first direct layout"
197
+ buffer = self.mem_pool_host.kv_buffer
182
198
  try:
183
199
  buffer_ptr = buffer.data_ptr()
184
200
  buffer_size = buffer.numel() * buffer.element_size()
@@ -189,6 +205,97 @@ class MooncakeStore(HiCacheStorage):
189
205
  logger.error("Failed to register buffer to Mooncake Store: %s", err)
190
206
  raise TypeError("Mooncake Store Register Buffer Error.") from err
191
207
 
208
+ def _get_mha_buffer_meta(self, keys, indices):
209
+ ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
210
+ key_list = []
211
+ for key_ in keys:
212
+ key_list.append(f"{key_}_{self.local_rank}_k")
213
+ key_list.append(f"{key_}_{self.local_rank}_v")
214
+ assert len(key_list) == len(ptr_list)
215
+ return key_list, ptr_list, element_size_list
216
+
217
+ def _get_mla_buffer_meta(self, keys, indices):
218
+ ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
219
+ key_list = []
220
+ for key_ in keys:
221
+ key_list.append(f"{key_}_k")
222
+ assert len(key_list) == len(ptr_list)
223
+ return key_list, ptr_list, element_size_list
224
+
225
+ def _batch_preprocess(self, keys, host_indices):
226
+ assert len(keys) > 0
227
+ assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
228
+ if self.is_mla_backend:
229
+ return self._get_mla_buffer_meta(keys, host_indices)
230
+ else:
231
+ return self._get_mha_buffer_meta(keys, host_indices)
232
+
233
+ def _batch_postprocess(self, results: List[int], is_set_operate=False):
234
+ """
235
+ refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
236
+ for batch_get_into, results is Vector of integers,
237
+ where each element is the number of bytes read on success, or a negative value on error
238
+ for batch_put_from, results is Vector of integers,
239
+ where each element is 0 on success, or a negative value on error
240
+ """
241
+ if self.is_mla_backend:
242
+ return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
243
+ else:
244
+ kv_pairs = zip(results[::2], results[1::2])
245
+ return [
246
+ (
247
+ (k_res == 0 and v_res == 0)
248
+ if is_set_operate
249
+ else (k_res > 0 and v_res > 0)
250
+ )
251
+ for k_res, v_res in kv_pairs
252
+ ]
253
+
254
def batch_get_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """Fetch a batch of pages from the store into host buffers.

    The batch is expanded into storage keys and buffer metadata by
    ``_batch_preprocess``, read through ``_get_batch_zero_copy_impl``, and
    the raw status codes are mapped to booleans by ``_batch_postprocess``
    using the get convention.

    Args:
        keys: Per-page key prefixes.
        host_indices: Host indices passed to ``_batch_preprocess``.
        extra_info: Accepted for interface compatibility; not used by this
            implementation.

    Returns:
        The list of success flags produced by ``_batch_postprocess``.
    """
    storage_keys, ptrs, sizes = self._batch_preprocess(keys, host_indices)
    raw_status = self._get_batch_zero_copy_impl(storage_keys, ptrs, sizes)
    return self._batch_postprocess(raw_status, is_set_operate=False)
265
+
266
def batch_set_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """Write a batch of pages from host buffers to the store.

    Storage keys that ``_batch_exist`` reports with status ``1`` are not
    written again and are recorded as successful. All other keys are sent
    in one ``_put_batch_zero_copy_impl`` call, and the per-key status is
    mapped to booleans by ``_batch_postprocess`` using the put convention.

    Args:
        keys: Per-page key prefixes.
        host_indices: Host indices passed to ``_batch_preprocess``.
        extra_info: Accepted for interface compatibility; not used by this
            implementation.

    Returns:
        The list of success flags produced by ``_batch_postprocess``.
    """
    key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
    exist_result = self._batch_exist(key_strs)

    # Positions of storage keys that are not reported as already present.
    pending = [pos for pos, _ in enumerate(key_strs) if exist_result[pos] != 1]

    # 0 marks "already stored" (a success under the put convention); pending
    # slots start at -1 and are overwritten with the real put status below.
    set_results = [0] * len(key_strs)
    for pos in pending:
        set_results[pos] = -1

    # Only set non-existing keys to storage
    if pending:
        put_results = self._put_batch_zero_copy_impl(
            [key_strs[pos] for pos in pending],
            [buffer_ptrs[pos] for pos in pending],
            [buffer_sizes[pos] for pos in pending],
        )
        for put_pos, pos in enumerate(pending):
            set_results[pos] = put_results[put_pos]

    return self._batch_postprocess(set_results, is_set_operate=True)
298
+
192
299
  def set(
193
300
  self,
194
301
  key,