sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/lmcache/unit_test.py (new file)
@@ -0,0 +1,121 @@
+ try:
+     from lmcache.integration.sglang.sglang_adapter import (
+         LMCacheLayerwiseConnector,
+         LoadMetadata,
+         StoreMetadata,
+     )
+ except ImportError:
+     raise RuntimeError(
+         "LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
+     )
+
+ import os
+
+ import torch
+
+ from sglang.srt.configs.model_config import ModelConfig
+
+ os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
+ os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
+
+
+ def test_load_store_metadata():
+     model_config = ModelConfig(
+         model_path="Qwen/Qwen3-4B",
+     )
+
+     # Generate dummy KV cache
+     head_num = model_config.num_key_value_heads
+     head_dim = model_config.head_dim
+     layer_num = model_config.num_hidden_layers
+     buffer_size = 256
+     input_id_len = 16
+
+     k_buffer = [
+         torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+     v_buffer = [
+         torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+
+     connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)
+
+     fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
+     fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
+     offset = 0
+
+     store_metadata = StoreMetadata(
+         last_node=None,
+         token_ids=fake_token_ids,
+         kv_indices=fake_kv_indices,
+         offset=offset,
+     )
+
+     load_metadata = LoadMetadata(
+         token_ids=fake_token_ids,
+         slot_mapping=fake_kv_indices,
+         offset=offset,
+     )
+
+     current_stream = torch.cuda.current_stream()
+
+     retrieve_token_num = connector.start_load_kv(load_metadata)
+     assert retrieve_token_num == 0
+
+     connector.store_kv(store_metadata)
+     current_stream.synchronize()
+
+     # check retrieve
+     gt_key_buffer = [
+         torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+     gt_value_buffer = [
+         torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+
+     for i in range(layer_num):
+         gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
+         gt_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+     # clear the k_buffer and v_buffer
+     for i in range(layer_num):
+         k_buffer[i].zero_()
+         v_buffer[i].zero_()
+
+     retrieve_token_num = connector.start_load_kv(load_metadata)
+     assert retrieve_token_num == input_id_len
+
+     for i in range(layer_num):
+         current_stream.synchronize()
+         connector.load_kv_layerwise(i)
+
+     current_stream.synchronize()
+     test_key_buffer = [
+         torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+     test_value_buffer = [
+         torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+         for _ in range(layer_num)
+     ]
+
+     for i in range(layer_num):
+         test_key_buffer[i] = k_buffer[i][fake_kv_indices]
+         test_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+     for i in range(layer_num):
+         assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
+         assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])
+
+     print("================================================")
+     print("TEST_LOAD_STORE_METADATA PASSED!")
+     print("================================================")
+     connector.close()
+
+
+ if __name__ == "__main__":
+     test_load_store_metadata()
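
The test above exercises the layerwise retrieve flow: start_load_kv reports how many prefix tokens the cache can serve, and each layer's KV copy is then enqueued on the current CUDA stream and synchronized. A minimal sketch of that call pattern, assuming only the connector methods used in the test; this is an illustration, not code from the package:

    import torch

    def retrieve_all_layers(connector, load_metadata, layer_num: int) -> int:
        """Illustrative wrapper around the calls shown in unit_test.py."""
        stream = torch.cuda.current_stream()
        num_tokens = connector.start_load_kv(load_metadata)  # 0 on cache miss
        if num_tokens == 0:
            return 0
        for layer_id in range(layer_num):
            stream.synchronize()                   # wait for the prior copy
            connector.load_kv_layerwise(layer_id)  # enqueue this layer's copy
        stream.synchronize()                       # drain the final copy
        return num_tokens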
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -1,4 +1,3 @@
- import hashlib
  import json
  import logging
  import os
@@ -6,10 +5,8 @@ import uuid
  from dataclasses import dataclass
  from typing import Any, List, Optional

- import numpy as np
  import torch

- from sglang.srt.distributed import get_tensor_model_parallel_rank
  from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig

  DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
@@ -75,6 +72,26 @@ class MooncakeStoreConfig:
              master_server_address=os.getenv("MOONCAKE_MASTER"),
          )

+     @staticmethod
+     def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
+         """Load config from extra_config dictionary."""
+         if "master_server_address" not in extra_config:
+             raise ValueError("master_server_address is required in extra_config")
+
+         return MooncakeStoreConfig(
+             local_hostname=extra_config.get("local_hostname", "localhost"),
+             metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
+             global_segment_size=extra_config.get(
+                 "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+             ),
+             local_buffer_size=extra_config.get(
+                 "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+             ),
+             protocol=extra_config.get("protocol", "tcp"),
+             device_name=extra_config.get("device_name", "auto"),
+             master_server_address=extra_config["master_server_address"],
+         )
+
      def __post_init__(self):
          if self.device_name == "auto":
              os.environ["MC_MS_AUTO_DISC"] = "1"
@@ -96,8 +113,26 @@ class MooncakeStore(HiCacheStorage):

          try:
              self.store = MooncakeDistributedStore()
-             self.config = MooncakeStoreConfig.load_from_env()
-             logger.info("Mooncake Configuration loaded from env successfully.")
+
+             extra_config = (
+                 getattr(storage_config, "extra_config", None)
+                 if storage_config
+                 else None
+             )
+             # Load configuration with master_server_address prioritized from extra_config if available
+             if (
+                 extra_config is not None
+                 and extra_config.get("master_server_address") is not None
+             ):
+                 # Load from extra_config
+                 self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
+                 logger.info(
+                     "Mooncake Configuration loaded from extra_config successfully."
+                 )
+             else:
+                 # Load from environment variables
+                 self.config = MooncakeStoreConfig.load_from_env()
+                 logger.info("Mooncake Configuration loaded from env successfully.")

              ret_code = self.store.setup(
                  self.config.local_hostname,
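
Together, the two hunks above let a deployment configure Mooncake through the storage backend's extra_config dictionary instead of environment variables, with master_server_address selecting between the two paths. A minimal sketch of calling the new loader, assuming a hand-written dict (the addresses and values below are invented for illustration):

    from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
        MooncakeStoreConfig,
    )

    extra_config = {
        "master_server_address": "10.0.0.1:50051",  # required; no default
        "protocol": "rdma",                         # optional; defaults to "tcp"
    }
    config = MooncakeStoreConfig.load_from_extra_config(extra_config)
    assert config.metadata_server == "P2PHANDSHAKE"  # default when omitted
    assert config.device_name == "auto"              # default when omitted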
@@ -154,20 +189,36 @@ class MooncakeStore(HiCacheStorage):
          target_location: Optional[List[int]] = None,
          target_sizes: Optional[List[int]] = None,
      ) -> bool:
-         return self.batch_set([key], [value], [target_location], [target_sizes])
+         # Only support zero copy set for now
+         assert target_location is not None and target_sizes is not None
+         exist_result = self._batch_exist([key])
+         if exist_result[0] == 1:
+             return True
+         put_result = self._put_batch_zero_copy_impl(
+             [key], [target_location], [target_sizes]
+         )
+         return put_result[0] == 0

      def batch_set(
          self,
          keys: List[str],
-         target_location: Optional[List[int]] = None,
+         values: Optional[List[torch.Tensor]] = None,
+         target_locations: Optional[List[int]] = None,
          target_sizes: Optional[List[int]] = None,
      ) -> bool:
-         assert len(keys) == len(target_location) == len(target_sizes)
+         # Only support zero copy set for now
+         assert target_locations is not None and target_sizes is not None
+         assert len(keys) == len(target_locations) == len(target_sizes)
+
          if len(keys) == 0:
              return False

          for i in range(len(keys)):
-             if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
+             if (
+                 keys[i] is None
+                 or target_locations[i] is None
+                 or target_sizes[i] is None
+             ):
                  return False

          exist_result = self._batch_exist(keys)
@@ -178,7 +229,7 @@ class MooncakeStore(HiCacheStorage):
          for i in range(len(keys)):
              if exist_result[i] != 1:
                  set_keys.append(keys[i])
-                 set_target_locations.append(target_location[i])
+                 set_target_locations.append(target_locations[i])
                  set_target_sizes.append(target_sizes[i])
                  set_indices.append(i)
          # Only set non-existing keys to storage
@@ -203,18 +254,24 @@ class MooncakeStore(HiCacheStorage):
          target_location: Optional[Any] = None,
          target_sizes: Optional[Any] = None,
      ) -> bool:
-         return self.batch_get([key], [target_location], [target_sizes]) == 1
+         assert target_location is not None and target_sizes is not None
+         get_result = self._get_batch_zero_copy_impl(
+             [key], [target_location], [target_sizes]
+         )
+         return get_result[0] >= 0

      def batch_get(
          self,
          keys: List[str],
-         target_location: Optional[Any] = None,
+         target_locations: Optional[Any] = None,
          target_sizes: Optional[Any] = None,
      ) -> int:
-         assert len(keys) == len(target_location) == len(target_sizes)
+         assert len(keys) == len(target_locations) == len(target_sizes)
          if len(keys) == 0:
              return 0
-         get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+         get_result = self._get_batch_zero_copy_impl(
+             keys, target_locations, target_sizes
+         )
          if self.is_mla_backend:
              key_multiplier = 1
          else:
@@ -225,7 +282,8 @@ class MooncakeStore(HiCacheStorage):
          return len(keys) // key_multiplier

      def exists(self, key) -> bool:
-         return self.batch_exists([key]) > 0
+         exist_result = self._batch_exist([key])
+         return exist_result[0] == 1

      def batch_exists(self, keys) -> int:
          if self.is_mla_backend:
@@ -244,16 +302,13 @@ class MooncakeStore(HiCacheStorage):
              return i // key_multiplier
          return len(query_keys) // key_multiplier

-     def delete(self, key) -> None:
-         raise (NotImplementedError)
-
      def close(self):
          # MooncakeDistributedStore will automatically call the destructor, so
          # it is unnecessary to close it manually.
          pass

      def clear(self) -> None:
-         raise (NotImplementedError)
+         self.store.remove_all()

      def _put_batch_zero_copy_impl(
          self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
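
Note the key_multiplier accounting in batch_get and batch_exists above: a non-MLA model stores each logical page under two physical keys (one K, one V, as the suffix scheme in the test below shows), so the counts reported to callers are divided by two. A standalone illustration of that arithmetic, not code from the package:

    def pages_hit(num_physical_hits: int, is_mla_backend: bool) -> int:
        # Non-MLA models write separate K and V entries per page, so the
        # store sees two physical keys for every logical page.
        key_multiplier = 1 if is_mla_backend else 2
        return num_physical_hits // key_multiplier

    assert pages_hit(6, is_mla_backend=False) == 3  # 3 pages -> 6 K/V keys
    assert pages_hit(6, is_mla_backend=True) == 6   # MLA: one entry per page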
sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py (new file)
@@ -0,0 +1,161 @@
+ import logging
+ import uuid
+
+ import torch
+ from mooncake_store import MooncakeStore
+
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def generate_batch_query_keys(kv_num: int, config: HiCacheStorageConfig):
+     keys = []
+     for _ in range(kv_num):
+         key = "test_" + str(uuid.uuid4())
+         keys.append(key)
+     set_keys = []
+     for key in keys:
+         if config.is_mla_model:
+             set_keys.append(key + "_k")
+         else:
+             set_keys.append(key + f"_{config.tp_rank}_k")
+             set_keys.append(key + f"_{config.tp_rank}_v")
+     get_keys = set_keys
+     exist_keys = keys
+     return set_keys, get_keys, exist_keys
+
+
+ def test_single_operation():
+     """Test the set API with a single key-value pair."""
+     print("=" * 100)
+     print("Testing single operation")
+
+     buffer_size = 1024 * 1024 * 16  # 16MB
+     value_elements = 1024
+     store = MooncakeStore()
+     buffer = torch.randn(buffer_size, dtype=torch.float32)
+     store.register_buffer(buffer)
+     value_size = value_elements * buffer.element_size()
+
+     key = str(uuid.uuid4())
+     set_slice = buffer[:value_elements]
+     get_slice = buffer[value_elements : 2 * value_elements]
+     set_location = set_slice.data_ptr()
+     get_location = get_slice.data_ptr()
+
+     # Test set operation
+     result = store.set(key, target_location=set_location, target_sizes=value_size)
+     assert result is True, f"❌set operation failed for key: {key}"
+
+     # Test exists operation
+     assert store.exists(key), f"❌key {key} should exist after set operation"
+
+     # Test get operation
+     result = store.get(key, target_location=get_location, target_sizes=value_size)
+     assert result is True, f"❌get operation failed for key: {key}"
+
+     # Compare the data using proper tensor indices
+     assert torch.allclose(
+         set_slice, get_slice, atol=1e-6
+     ), f"❌get operation failed for key: {key}"
+
+     logger.info(f"✅ Single operation passed")
+
+
+ def test_batch_operation(config: HiCacheStorageConfig):
+     """Test the batch set/get APIs with multiple key-value pairs."""
+     print("=" * 100)
+     print(f"Testing batch operation with config: {config}")
+
+     buffer_size = 1024 * 1024 * 16  # 16MB
+     value_elements = 256
+     kv_num = 13
+     store = MooncakeStore(config)
+     buffer = torch.randn(buffer_size, dtype=torch.float32)
+     store.register_buffer(buffer)
+     value_size = value_elements * buffer.element_size()
+
+     set_keys, get_keys, exist_keys = generate_batch_query_keys(kv_num, config)
+     set_slices = [
+         buffer[i * value_elements : (i + 1) * value_elements]
+         for i in range(len(set_keys))
+     ]
+     set_locations = [set_slice.data_ptr() for set_slice in set_slices]
+     target_sizes = [value_size for _ in range(len(set_keys))]
+
+     # Test batch set operation
+     result = store.batch_set(
+         set_keys, target_locations=set_locations, target_sizes=target_sizes
+     )
+     assert result is True, f"❌batch set operation failed"
+
+     # Test batch exists operation
+     assert store.batch_exists(
+         exist_keys
+     ), f"❌keys should exist after batch set operation"
+
+     # Test batch get operation
+     get_slices = [
+         buffer[
+             (len(set_keys) + i)
+             * value_elements : (len(set_keys) + i + 1)
+             * value_elements
+         ]
+         for i in range(len(get_keys))
+     ]
+     get_locations = [get_slice.data_ptr() for get_slice in get_slices]
+     result = store.batch_get(
+         get_keys, target_locations=get_locations, target_sizes=target_sizes
+     )
+     assert result == kv_num, f"❌batch get operation failed"
+     for i in range(len(get_keys)):
+         assert torch.allclose(
+             set_slices[i], get_slices[i], atol=1e-6
+         ), f"❌batch get operation failed for key: {get_keys[i]}"
+
+     logger.info(f"✅ Batch operation passed")
+
+
+ if __name__ == "__main__":
+     test_single_operation()
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=False,
+             tp_rank=0,
+             tp_size=1,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=True,
+             tp_rank=0,
+             tp_size=1,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=False,
+             tp_rank=1,
+             tp_size=4,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     test_batch_operation(
+         HiCacheStorageConfig(
+             is_mla_model=True,
+             tp_rank=3,
+             tp_size=8,
+             model_name=None,
+             is_page_first_layout=True,
+         )
+     )
+     logger.info(f"✅ All tests passed")
sglang/srt/mem_cache/swa_radix_cache.py
@@ -60,8 +60,6 @@ class TreeNode:
          self.last_access_time = time.monotonic()

          self.hit_count = 0
-         # indicating the node is loading KV cache from host
-         self.loading = False
          # store the host indices of KV cache
          self.host_value = None

@@ -464,7 +462,7 @@ class SWARadixCache(BasePrefixCache):
          self.req_to_token_pool.free(req.req_pool_idx)
          self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)

-     def cache_unfinished_req(self, req: Req) -> None:
+     def cache_unfinished_req(self, req: Req, chunked=False) -> None:
          """Cache request when it is unfinished."""
          if self.disable:
              kv_indices = self.req_to_token_pool.req_to_token[