sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages' contents as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/lmcache/unit_test.py
@@ -0,0 +1,121 @@
+ try:
+     from lmcache.integration.sglang.sglang_adapter import (
+         LMCacheLayerwiseConnector,
+         LoadMetadata,
+         StoreMetadata,
+     )
+ except ImportError:
+     raise RuntimeError(
+         "LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
+     )
+
+ import os
+
+ import torch
+
+ from sglang.srt.configs.model_config import ModelConfig
+
+ os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
+ os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
+
+
+ def test_load_store_metadata():
+     model_config = ModelConfig(
+         model_path="Qwen/Qwen3-4B",
+     )
+
+     # Generate Dummy KV Cache
+     head_num = model_config.num_key_value_heads
+     head_dim = model_config.head_dim
+     layer_num = model_config.num_hidden_layers
+     buffer_size = 256
+     input_id_len = 16
+
+     k_buffer = [
+         torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+     v_buffer = [
+         torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+
+     connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)
+
+     fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
+     fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
+     offset = 0
+
+     store_metadata = StoreMetadata(
+         last_node=None,
+         token_ids=fake_token_ids,
+         kv_indices=fake_kv_indices,
+         offset=offset,
+     )
+
+     load_metadata = LoadMetadata(
+         token_ids=fake_token_ids,
+         slot_mapping=fake_kv_indices,
+         offset=offset,
+     )
+
+     current_stream = torch.cuda.current_stream()
+
+     retrieve_token_num = connector.start_load_kv(load_metadata)
+     assert retrieve_token_num == 0
+
+     connector.store_kv(store_metadata)
+     current_stream.synchronize()
+
+     # check retrieve
+     gt_key_buffer = [
+         torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+     gt_value_buffer = [
+         torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+
+     for i in range(layer_num):
+         gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
+         gt_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+ # clear the k_buffer and v_buffer
85
+ for _ in range(layer_num):
86
+ k_buffer[i].zero_()
87
+ v_buffer[i].zero_()
88
+
89
+ retrieve_token_num = connector.start_load_kv(load_metadata)
90
+ assert retrieve_token_num == input_id_len
91
+
92
+ for i in range(layer_num):
93
+ current_stream.synchronize()
94
+ connector.load_kv_layerwise(i)
95
+
96
+ current_stream.synchronize()
97
+ test_key_buffer = [
98
+ torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
99
+ for _ in range(layer_num)
100
+ ]
101
+ test_value_buffer = [
102
+ torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
103
+ for _ in range(layer_num)
104
+ ]
105
+
106
+ for i in range(layer_num):
107
+ test_key_buffer[i] = k_buffer[i][fake_kv_indices]
108
+ test_value_buffer[i] = v_buffer[i][fake_kv_indices]
109
+
110
+ for i in range(layer_num):
111
+ assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
112
+ assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])
113
+
114
+ print("================================================")
115
+ print("TEST_LOAD_STORE_METADATA PASSED!")
116
+ print("================================================")
117
+ connector.close()
118
+
119
+
120
+ if __name__ == "__main__":
121
+ test_load_store_metadata()
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -1,4 +1,3 @@
- import hashlib
  import json
  import logging
  import os
@@ -6,28 +5,16 @@ import uuid
  from dataclasses import dataclass
  from typing import Any, List, Optional

- import numpy as np
  import torch

- from sglang.srt.distributed import get_tensor_model_parallel_rank
- from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig

  DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
- DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
+ DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB

  logger = logging.getLogger(__name__)


- def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-     prefix_str = ""
-     if prior_hash:
-         prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
-     current_token_ids_bytes = np.array(token_ids).tobytes()
-     current_hash_object = hashlib.sha256(current_token_ids_bytes)
-     current_hash_hex = current_hash_object.hexdigest()
-     return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
-
-
  @dataclass
  class MooncakeStoreConfig:
      local_hostname: str
@@ -54,9 +41,8 @@ class MooncakeStoreConfig:
              global_segment_size=config.get(
                  "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
              ),
-             local_buffer_size=config.get(
-                 "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
-             ),
+             # Zero copy interface does not need local buffer
+             local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
              protocol=config.get("protocol", "tcp"),
              device_name=config.get("device_name", "auto"),
              master_server_address=config.get("master_server_address"),
@@ -79,14 +65,33 @@ class MooncakeStoreConfig:
              global_segment_size=int(
                  os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
              ),
-             local_buffer_size=int(
-                 os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
-             ),
+             # Zero copy interface does not need local buffer
+             local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
              protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
              device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
              master_server_address=os.getenv("MOONCAKE_MASTER"),
          )

+     @staticmethod
+     def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
+         """Load config from extra_config dictionary."""
+         if "master_server_address" not in extra_config:
+             raise ValueError("master_server_address is required in extra_config")
+
+         return MooncakeStoreConfig(
+             local_hostname=extra_config.get("local_hostname", "localhost"),
+             metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
+             global_segment_size=extra_config.get(
+                 "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+             ),
+             local_buffer_size=extra_config.get(
+                 "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+             ),
+             protocol=extra_config.get("protocol", "tcp"),
+             device_name=extra_config.get("device_name", "auto"),
+             master_server_address=extra_config["master_server_address"],
+         )
+
      def __post_init__(self):
          if self.device_name == "auto":
              os.environ["MC_MS_AUTO_DISC"] = "1"
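A quick illustration of the `load_from_extra_config` entry point added above: only `master_server_address` is required, and the remaining keys fall back to the defaults shown in this hunk. The dictionary below is a hypothetical example (the address is a placeholder, not taken from this release):

# Hypothetical extra_config consumed by MooncakeStoreConfig.load_from_extra_config();
# only master_server_address is mandatory, the other keys use the defaults above.
extra_config = {
    "master_server_address": "10.0.0.1:50051",  # placeholder endpoint
    "protocol": "tcp",
    "global_segment_size": 4 * 1024 * 1024 * 1024,  # 4 GiB
}
config = MooncakeStoreConfig.load_from_extra_config(extra_config)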
@@ -96,7 +101,7 @@ class MooncakeStoreConfig:


  class MooncakeStore(HiCacheStorage):
-     def __init__(self, is_mla: bool = False):
+     def __init__(self, storage_config: HiCacheStorageConfig = None):
          try:
              from mooncake.store import MooncakeDistributedStore
          except ImportError as e:
@@ -108,8 +113,26 @@ class MooncakeStore(HiCacheStorage):

          try:
              self.store = MooncakeDistributedStore()
-             self.config = MooncakeStoreConfig.load_from_env()
-             logger.info("Mooncake Configuration loaded from env successfully.")
+
+             extra_config = (
+                 getattr(storage_config, "extra_config", None)
+                 if storage_config
+                 else None
+             )
+             # Load configuration with master_server_address prioritized from extra_config if available
+             if (
+                 extra_config is not None
+                 and extra_config.get("master_server_address") is not None
+             ):
+                 # Load from extra_config
+                 self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
+                 logger.info(
+                     "Mooncake Configuration loaded from extra_config successfully."
+                 )
+             else:
+                 # Load from environment variables
+                 self.config = MooncakeStoreConfig.load_from_env()
+                 logger.info("Mooncake Configuration loaded from env successfully.")

              ret_code = self.store.setup(
                  self.config.local_hostname,
@@ -126,7 +149,13 @@ class MooncakeStore(HiCacheStorage):
              logger.info("Connect to Mooncake store successfully.")
              self.warmup()
              logger.info("Mooncake store warmup successfully.")
-             self.is_mla = is_mla
+
+             if storage_config is not None:
+                 self.is_mla_backend = storage_config.is_mla_model
+                 self.local_rank = storage_config.tp_rank
+             else:
+                 self.is_mla_backend = False
+                 self.local_rank = 0

          except ValueError as e:
              logger.error("Configuration loading failed: %s", e)
@@ -137,12 +166,10 @@ class MooncakeStore(HiCacheStorage):

      def warmup(self):
          warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
-         # 10 MB
-         warmup_value = bytes(10 * 1024 * 1024)
-         self.store.put(warmup_key, warmup_value)
+         warmup_value = bytes(4 * 1024) # 4 KB
+         assert self.store.put(warmup_key, warmup_value) == 0
          assert self.store.is_exist(warmup_key) == 1
-         self.store.get(warmup_key)
-         self.store.remove(warmup_key)
+         assert self.store.get(warmup_key) == warmup_value

      def register_buffer(self, buffer: torch.Tensor) -> None:
          try:
@@ -162,81 +189,118 @@ class MooncakeStore(HiCacheStorage):
          target_location: Optional[List[int]] = None,
          target_sizes: Optional[List[int]] = None,
      ) -> bool:
-         assert len(key) == len(target_location) == len(target_sizes)
-         if len(key) == 0:
-             return
-
-         for i in range(len(key)):
-             if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                 return
-
-         self._put_batch_zero_copy_impl(key, target_location, target_sizes)
+         # Only support zero copy set for now
+         assert target_location is not None and target_sizes is not None
+         exist_result = self._batch_exist([key])
+         if exist_result[0] == 1:
+             return True
+         put_result = self._put_batch_zero_copy_impl(
+             [key], [target_location], [target_sizes]
+         )
+         return put_result[0] == 0

      def batch_set(
          self,
          keys: List[str],
-         value: Optional[Any] = None,
-         target_location: Optional[List[int]] = None,
+         values: Optional[List[torch.Tensor]] = None,
+         target_locations: Optional[List[int]] = None,
          target_sizes: Optional[List[int]] = None,
      ) -> bool:
-         assert len(keys) == len(target_location) == len(target_sizes)
+         # Only support zero copy set for now
+         assert target_locations is not None and target_sizes is not None
+         assert len(keys) == len(target_locations) == len(target_sizes)
+
          if len(keys) == 0:
-             return
+             return False

          for i in range(len(keys)):
-             if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                 return
+             if (
+                 keys[i] is None
+                 or target_locations[i] is None
+                 or target_sizes[i] is None
+             ):
+                 return False
+
+         exist_result = self._batch_exist(keys)
+         set_keys = []
+         set_target_locations = []
+         set_target_sizes = []
+         set_indices = []
+         for i in range(len(keys)):
+             if exist_result[i] != 1:
+                 set_keys.append(keys[i])
+                 set_target_locations.append(target_locations[i])
+                 set_target_sizes.append(target_sizes[i])
+                 set_indices.append(i)
+         # Only set non-existing keys to storage
+         put_result = self._put_batch_zero_copy_impl(
+             set_keys, set_target_locations, set_target_sizes
+         )
+         for i in range(len(set_indices)):
+             if put_result[i] == 0:
+                 exist_result[set_indices[i]] = 1

-         self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
+         success_count = 0
+         for i in range(len(keys)):
+             if exist_result[i] == 0:
+                 break
+             success_count += 1
+         # TODO: return the number of consecutive successful operations from the start.
+         return success_count == len(keys)

      def get(
          self,
          key,
          target_location: Optional[Any] = None,
          target_sizes: Optional[Any] = None,
-     ) -> torch.Tensor | None:
-         assert len(key) == len(target_location) == len(target_sizes)
-         if len(key) == 0:
-             return
-
-         for i in range(len(key)):
-             if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                 return
-
-         return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
+     ) -> bool:
+         assert target_location is not None and target_sizes is not None
+         get_result = self._get_batch_zero_copy_impl(
+             [key], [target_location], [target_sizes]
+         )
+         return get_result[0] >= 0

      def batch_get(
          self,
          keys: List[str],
-         target_location: Optional[Any] = None,
+         target_locations: Optional[Any] = None,
          target_sizes: Optional[Any] = None,
-     ) -> torch.Tensor | None:
-         assert len(keys) == len(target_location) == len(target_sizes)
+     ) -> int:
+         assert len(keys) == len(target_locations) == len(target_sizes)
          if len(keys) == 0:
-             return
-
+             return 0
+         get_result = self._get_batch_zero_copy_impl(
+             keys, target_locations, target_sizes
+         )
+         if self.is_mla_backend:
+             key_multiplier = 1
+         else:
+             key_multiplier = 2
          for i in range(len(keys)):
-             if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                 return
-
-         return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
-
-     def exists(self, keys) -> bool | dict:
-         _keys = []
-         local_rank = get_tensor_model_parallel_rank()
-         for key in keys:
-             if key is None:
-                 return None
-
-             if self.is_mla:
-                 _keys.append(f"{key}_k")
-             else:
-                 _keys.append(f"{key}_{local_rank}_k")
-         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
-         return result
-
-     def delete(self, key) -> None:
-         raise (NotImplementedError)
+             if get_result[i] < 0:
+                 return i // key_multiplier
+         return len(keys) // key_multiplier
+
+     def exists(self, key) -> bool:
+         exist_result = self._batch_exist([key])
+         return exist_result[0] == 1
+
+     def batch_exists(self, keys) -> int:
+         if self.is_mla_backend:
+             query_keys = [f"{key}_k" for key in keys]
+             key_multiplier = 1
+         else:
+             query_keys = []
+             for key in keys:
+                 query_keys.append(f"{key}_{self.local_rank}_k")
+                 query_keys.append(f"{key}_{self.local_rank}_v")
+             key_multiplier = 2
+
+         exist_result = self._batch_exist(query_keys)
+         for i in range(len(query_keys)):
+             if exist_result[i] != 1:
+                 return i // key_multiplier
+         return len(query_keys) // key_multiplier

      def close(self):
          # MooncakeDistributedStore will automatically call the destructor, so
@@ -244,22 +308,17 @@ class MooncakeStore(HiCacheStorage):
          pass

      def clear(self) -> None:
-         raise (NotImplementedError)
+         self.store.remove_all()

      def _put_batch_zero_copy_impl(
          self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-     ) -> None:
-         try:
-             self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
-         except TypeError as err:
-             logger.error("Failed to put value to Mooncake Store: %s", err)
-             raise TypeError("Mooncake Store Put Type Error.") from err
+     ) -> List[int]:
+         return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)

      def _get_batch_zero_copy_impl(
          self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-     ) -> None:
-         try:
-             self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
-         except TypeError as err:
-             logger.error("Failed to get value from Mooncake Store: %s", err)
-             raise TypeError("Mooncake Store Get Type Error.") from err
+     ) -> List[int]:
+         return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
+
+     def _batch_exist(self, key_strs: List[str]) -> List[int]:
+         return self.store.batch_is_exist(key_strs)
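To make the reworked existence API above easier to read: `exists` now takes a single key and returns a bool, while `batch_exists` returns how many leading keys are fully present (each non-MLA key expands to per-rank `_k`/`_v` entries, MLA keys to a single `_k` entry). A minimal, hypothetical sketch of consuming that prefix count, with `store` and `block_keys` as placeholder names rather than identifiers from the diff:

# Hypothetical caller of the new batch_exists(); it returns a count of
# consecutive leading keys whose entries all exist, not a per-key dict.
def split_cached_prefix(store, block_keys):
    n_hit = store.batch_exists(block_keys)
    # Keys [0, n_hit) can be fetched via batch_get(); the rest are misses.
    return block_keys[:n_hit], block_keys[n_hit:]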
sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py
@@ -0,0 +1,161 @@
+ import logging
+ import uuid
+
+ import torch
+ from mooncake_store import MooncakeStore
+
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def generate_batch_query_keys(kv_num: int, config: HiCacheStorageConfig):
+     keys = []
+     for _ in range(kv_num):
+         key = "test_" + str(uuid.uuid4())
+         keys.append(key)
+     set_keys = []
+     for key in keys:
+         if config.is_mla_model:
+             set_keys.append(key + "_k")
+         else:
+             set_keys.append(key + f"_{config.tp_rank}_k")
+             set_keys.append(key + f"_{config.tp_rank}_v")
+     get_keys = set_keys
+     exist_keys = keys
+     return set_keys, get_keys, exist_keys
+
+
+ def test_single_operation():
+     """Test the set API with a single key-value pair."""
+     print("=" * 100)
+     print("Testing single operation")
+
+     buffer_size = 1024 * 1024 * 16 # 16MB
+     value_elements = 1024
+     store = MooncakeStore()
+     buffer = torch.randn(buffer_size, dtype=torch.float32)
+     store.register_buffer(buffer)
+     value_size = value_elements * buffer.element_size()
+
+     key = str(uuid.uuid4())
+     set_slice = buffer[:value_elements]
+     get_slice = buffer[value_elements : 2 * value_elements]
+     set_location = set_slice.data_ptr()
+     get_location = get_slice.data_ptr()
+
+     # Test set operation
+     result = store.set(key, target_location=set_location, target_sizes=value_size)
+     assert result is True, f"❌set operation failed for key: {key}"
+
+     # Test exists operation
+     assert store.exists(key), f"❌key {key} should exist after set operation"
+
+     # Test get operation
+     result = store.get(key, target_location=get_location, target_sizes=value_size)
+     assert result is True, f"❌get operation failed for key: {key}"
+
+     # Compare the data using proper tensor indices
+     assert torch.allclose(
+         set_slice, get_slice, atol=1e-6
+     ), f"❌get operation failed for key: {key}"
+
+     logger.info(f"✅ Single operation passed")
+
+
+ def test_batch_operation(config: HiCacheStorageConfig):
+     """Test the batch set/get APIs with multiple key-value pairs."""
+     print("=" * 100)
+     print(f"Testing batch operation with config: {config}")
+
+     buffer_size = 1024 * 1024 * 16 # 16MB
+     value_elements = 256
+     kv_num = 13
+     store = MooncakeStore(config)
+     buffer = torch.randn(buffer_size, dtype=torch.float32)
+     store.register_buffer(buffer)
+     value_size = value_elements * buffer.element_size()
+
+     set_keys, get_keys, exist_keys = generate_batch_query_keys(kv_num, config)
+     set_slices = [
+         buffer[i * value_elements : (i + 1) * value_elements]
+         for i in range(len(set_keys))
+     ]
+     set_locations = [set_slice.data_ptr() for set_slice in set_slices]
+     target_sizes = [value_size for _ in range(len(set_keys))]
+
+     # Test batch set operation
+     result = store.batch_set(
+         set_keys, target_locations=set_locations, target_sizes=target_sizes
+     )
+     assert result is True, f"❌batch set operation failed"
+
+     # Test batch exists operation
+     assert store.batch_exists(
+         exist_keys
+     ), f"❌keys should exist after batch set operation"
+
+     # Test batch get operation
+     get_slices = [
+         buffer[
+             (len(set_keys) + i)
+             * value_elements : (len(set_keys) + i + 1)
+             * value_elements
+         ]
+         for i in range(len(get_keys))
+     ]
+     get_locations = [get_slice.data_ptr() for get_slice in get_slices]
+     result = store.batch_get(
+         get_keys, target_locations=get_locations, target_sizes=target_sizes
+     )
+     assert result == kv_num, f"❌batch get operation failed"
+     for i in range(len(get_keys)):
+         assert torch.allclose(
+             set_slices[i], get_slices[i], atol=1e-6
+         ), f"❌batch get operation failed for key: {get_keys[i]}"
+
+     logger.info(f"✅ Batch operation passed")
+
+
+ if __name__ == "__main__":
+     test_single_operation()
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=False,
+             tp_rank=0,
+             tp_size=1,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=True,
+             tp_rank=0,
+             tp_size=1,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=False,
+             tp_rank=1,
+             tp_size=4,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=True,
+             tp_rank=3,
+             tp_size=8,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     logger.info(f"✅ All tests passed")
sglang/srt/mem_cache/swa_radix_cache.py
@@ -60,8 +60,6 @@ class TreeNode:
          self.last_access_time = time.monotonic()

          self.hit_count = 0
-         # indicating the node is loading KV cache from host
-         self.loading = False
          # store the host indices of KV cache
          self.host_value = None

@@ -464,7 +462,7 @@ class SWARadixCache(BasePrefixCache):
          self.req_to_token_pool.free(req.req_pool_idx)
          self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)

-     def cache_unfinished_req(self, req: Req) -> None:
+     def cache_unfinished_req(self, req: Req, chunked=False) -> None:
          """Cache request when it is unfinished."""
          if self.disable:
              kv_indices = self.req_to_token_pool.req_to_token[