sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,280 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import threading
5
+ from typing import TYPE_CHECKING, List, Optional
6
+
7
+ import torch
8
+
9
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
10
+ from sglang.srt.mem_cache.base_prefix_cache import MatchResult
11
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
12
+ from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
13
+
14
+ try:
15
+ from lmcache.integration.sglang.sglang_adapter import (
16
+ LMCacheLayerwiseConnector,
17
+ LoadMetadata,
18
+ StoreMetadata,
19
+ )
20
+ except ImportError as e:
21
+ raise RuntimeError(
22
+ "LMCache is not installed. Please install it by running `pip install lmcache`"
23
+ ) from e
24
+
25
+ if TYPE_CHECKING:
26
+ from sglang.srt.configs.model_config import ModelConfig
27
+ from sglang.srt.managers.schedule_batch import Req
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
class LayerTransferCounter:
    """Bridge between the KV memory pool and LMCache's layer-wise loader.

    After the pool finishes a layer it calls ``wait_until(layer_id)``; we
    respond by synchronizing the load stream and then issuing
    ``load_kv_layerwise(layer_id)`` on the connector inside that stream.
    """

    def __init__(
        self,
        num_layers: int,
        load_stream: torch.cuda.Stream,
        lmc_connector: LMCacheLayerwiseConnector,
        printable: bool = False,
    ):
        # NOTE(review): `printable` is accepted for interface compatibility
        # but is not stored or used anywhere in this class.
        self.num_layers = num_layers
        self.load_stream = load_stream
        self.lmc_connector = lmc_connector

    def wait_until(self, layer_id: int):
        """Order the async load after prior stream work, then load one layer."""
        self.load_stream.synchronize()
        with self.load_stream:
            self.lmc_connector.load_kv_layerwise(layer_id)
56
+
57
+
58
class LMCRadixCache(RadixCache):
    """RadixCache extended with LMCache-backed KV transfer.

    Adds on top of the base radix cache:
      - LMCache connector setup (device/host buffers, TP rank/size)
      - two CUDA streams for async load/store
      - layer-wise transfer executor wiring into the KV cache
      - an overridden `match_prefix` that prefetches missing prefix chunks
        from LMCache into freshly allocated token slots
      - store-back into LMCache when a request finishes
      - an eviction barrier that waits for in-flight host->device stores
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
        enable_kv_cache_events: bool = False,
        model_config: Optional["ModelConfig"] = None,
        tp_size: int = 1,
        rank: int = 0,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        super().__init__(
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            page_size=page_size,
            disable=disable,
            enable_kv_cache_events=enable_kv_cache_events,
        )

        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
        # Prefer public k/v buffer accessors when the pool exposes them;
        # fall back to the allocator's private `_kvcache` fields otherwise.
        self.lmcache_connector = LMCacheLayerwiseConnector(
            sgl_config=model_config,
            tp_size=tp_size,
            rank=rank,
            k_pool=getattr(
                kvcache,
                "k_buffer",
                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
            ),
            v_pool=getattr(
                kvcache,
                "v_buffer",
                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
            ),
            tp_group=tp_group,
        )

        self.load_stream = torch.cuda.Stream()
        self.store_stream = torch.cuda.Stream()

        self.layer_done_executor = LayerTransferCounter(
            num_layers=(
                model_config.num_hidden_layers if model_config is not None else 0
            ),
            load_stream=self.load_stream,
            lmc_connector=self.lmcache_connector,
        )
        kvcache.register_layer_transfer_counter(self.layer_done_executor)

        # Nodes whose KV is still being stored to LMCache. Each holds a lock
        # ref (taken in `cache_finished_req`) released in `evict`.
        self._in_flight_nodes: list[TreeNode] = []
        self._node_lock = threading.Lock()

    def reset(self):  # type: ignore[override]
        """Clear the radix tree and drop any in-flight store bookkeeping."""
        super().reset()
        # `reset` may run from the base constructor before our attributes
        # exist, hence the hasattr guard.
        if hasattr(self, "_in_flight_nodes"):
            with self._node_lock:
                self._in_flight_nodes.clear()

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:  # type: ignore[override]
        """Match cached prefix; on a tail miss, prefetch it from LMCache.

        Reuses the base matching logic to obtain (value, last_node). If a
        *page-aligned* uncached suffix remains and token slots can be
        allocated (possibly after eviction), trigger an async LMCache load
        into those slots and materialize a new child node for the retrieved
        chunk.
        """
        if self.disable or not key:
            return super().match_prefix(key, **kwargs)

        if self.page_size != 1:
            aligned_len = len(key) // self.page_size * self.page_size
            key = key[:aligned_len]

        base_res = super().match_prefix(key, **kwargs)
        value: torch.Tensor = base_res.device_indices
        last_node: TreeNode = base_res.last_device_node

        if value.numel() == len(key):
            return base_res

        uncached_len = len(key) - value.numel()
        if uncached_len == 0:
            return base_res

        chunk_size = self.lmcache_connector.chunk_size()
        # Tokens of the matched prefix that do not fill a whole LMCache chunk.
        prefix_pad = value.numel() % chunk_size

        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
            self.evict(uncached_len)

        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
        if token_slots is None:
            return base_res

        # -1 marks positions already cached on device; real slots cover the
        # uncached suffix we are about to load.
        slot_mapping = torch.cat(
            [
                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
                token_slots.detach().clone().to(torch.int64).to(self.device),
            ]
        )

        with torch.cuda.stream(self.load_stream):
            num_retrieved = self.lmcache_connector.start_load_kv(
                LoadMetadata(
                    token_ids=key,  # full page-aligned key
                    slot_mapping=slot_mapping,
                    offset=value.numel() - prefix_pad,  # LMCache offset convention
                )
            )
        logger.debug("num_retrieved_tokens: %s", num_retrieved)

        if num_retrieved <= 0:
            # Nothing came back: release everything we allocated.
            self.token_to_kv_pool_allocator.free(token_slots)
            return base_res

        fetched = num_retrieved - prefix_pad
        # Release the slots beyond what LMCache actually retrieved.
        self.token_to_kv_pool_allocator.free(token_slots[fetched:])

        new_node = TreeNode()
        start = value.numel()
        end = start + fetched
        new_node.key = key[start:end]
        new_node.value = token_slots[:fetched]
        new_node.parent = last_node
        last_node.children[self.get_child_key_fn(new_node.key)] = new_node
        last_node = new_node

        value = torch.cat([value, token_slots[:fetched]])
        self.evictable_size_ += fetched

        self._record_store_event(new_node.parent)
        self._record_store_event(new_node)

        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
        """On request completion, insert device KV into radix and store to LMCache."""
        super().cache_finished_req(req)

        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        # BUG FIX: the original unpacked four values from match_prefix(), but
        # this class's match_prefix returns a MatchResult constructed with
        # exactly three fields; use attribute access so the field count of
        # MatchResult is irrelevant here.
        new_last_node = self.match_prefix(token_ids).last_device_node
        assert new_last_node is not None

        # Hold a lock ref while the async store is in flight; it is released
        # in `evict` once the store stream has been synchronized.
        self.inc_lock_ref(new_last_node)
        store_md = StoreMetadata(
            last_node=new_last_node,
            token_ids=token_ids,
            kv_indices=kv_indices,
            offset=0,
        )
        with torch.cuda.stream(self.store_stream):
            self.lmcache_connector.store_kv(store_md)
        with self._node_lock:
            self._in_flight_nodes.append(new_last_node)

    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
        """Before base eviction, wait for outstanding stores and release locks."""
        if self.disable:
            return

        self.store_stream.synchronize()
        with self._node_lock:
            for node in self._in_flight_nodes:
                self.dec_lock_ref(node)
            self._in_flight_nodes.clear()

        super().evict(num_tokens)

    def pretty_print(self):  # type: ignore[override]
        """Print the tree as the base class does, plus size counters at debug level."""
        super().pretty_print()
        try:
            logger.debug(
                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
            )
        except Exception:  # pragma: no cover
            pass
264
+
265
+
266
if __name__ == "__main__":
    # Manual smoke test: build a minimally configured cache, insert two
    # overlapping keys, and dump the tree.
    # NOTE(review): the None pools/allocator will not survive the
    # `get_kvcache()` call in __init__ — presumably meant for a stubbed
    # environment; confirm before relying on this path.
    cache = LMCRadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=1,
        disable=False,
        enable_kv_cache_events=False,
        model_config=None,
        tp_size=1,
        rank=0,
        tp_group=None,
    )
    cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
    cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
    cache.pretty_print()
@@ -0,0 +1,121 @@
1
+ try:
2
+ from lmcache.integration.sglang.sglang_adapter import (
3
+ LMCacheLayerwiseConnector,
4
+ LoadMetadata,
5
+ StoreMetadata,
6
+ )
7
+ except ImportError:
8
+ raise RuntimeError(
9
+ "LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
10
+ )
11
+
12
+ import os
13
+
14
+ import torch
15
+
16
+ from sglang.srt.configs.model_config import ModelConfig
17
+
18
# LMCache test-harness configuration: enable the experimental adapter API and
# point LMCache at the sample config file shipped next to this test.
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
20
+
21
+
22
def test_load_store_metadata():
    """Round-trip a dummy KV cache through the LMCache connector.

    Stores random K/V for a fake token sequence, zeroes the device buffers,
    reloads layer-by-layer via the connector, and checks the restored values
    match the originals.
    """
    model_config = ModelConfig(
        model_path="Qwen/Qwen3-4B",
    )

    # Generate a dummy KV cache.
    head_num = model_config.num_key_value_heads
    head_dim = model_config.head_dim
    layer_num = model_config.num_hidden_layers
    buffer_size = 256
    input_id_len = 16

    k_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    v_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)

    fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
    fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
    offset = 0

    store_metadata = StoreMetadata(
        last_node=None,
        token_ids=fake_token_ids,
        kv_indices=fake_kv_indices,
        offset=offset,
    )

    load_metadata = LoadMetadata(
        token_ids=fake_token_ids,
        slot_mapping=fake_kv_indices,
        offset=offset,
    )

    current_stream = torch.cuda.current_stream()

    # Nothing stored yet, so the first load must retrieve zero tokens.
    retrieve_token_num = connector.start_load_kv(load_metadata)
    assert retrieve_token_num == 0

    connector.store_kv(store_metadata)
    current_stream.synchronize()

    # Snapshot the ground-truth K/V slices before clearing the buffers.
    gt_key_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    gt_value_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    for i in range(layer_num):
        gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
        gt_value_buffer[i] = v_buffer[i][fake_kv_indices]

    # Clear the k_buffer and v_buffer.
    # BUG FIX: the original looped `for _ in range(layer_num)` while indexing
    # with the stale `i` from the previous loop, so only the last layer was
    # zeroed and the reload check below was largely vacuous.
    for i in range(layer_num):
        k_buffer[i].zero_()
        v_buffer[i].zero_()

    retrieve_token_num = connector.start_load_kv(load_metadata)
    assert retrieve_token_num == input_id_len

    for i in range(layer_num):
        current_stream.synchronize()
        connector.load_kv_layerwise(i)

    current_stream.synchronize()
    test_key_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    test_value_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    for i in range(layer_num):
        test_key_buffer[i] = k_buffer[i][fake_kv_indices]
        test_value_buffer[i] = v_buffer[i][fake_kv_indices]

    for i in range(layer_num):
        assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
        assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])

    print("================================================")
    print("TEST_LOAD_STORE_METADATA PASSED!")
    print("================================================")
    connector.close()
118
+
119
+
120
if __name__ == "__main__":
    # Run the single connector round-trip test when executed directly.
    test_load_store_metadata()
@@ -1,4 +1,3 @@
1
- import hashlib
2
1
  import json
3
2
  import logging
4
3
  import os
@@ -6,10 +5,8 @@ import uuid
6
5
  from dataclasses import dataclass
7
6
  from typing import Any, List, Optional
8
7
 
9
- import numpy as np
10
8
  import torch
11
9
 
12
- from sglang.srt.distributed import get_tensor_model_parallel_rank
13
10
  from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
14
11
 
15
12
  DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
@@ -75,6 +72,26 @@ class MooncakeStoreConfig:
75
72
  master_server_address=os.getenv("MOONCAKE_MASTER"),
76
73
  )
77
74
 
75
+ @staticmethod
76
+ def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
77
+ """Load config from extra_config dictionary."""
78
+ if "master_server_address" not in extra_config:
79
+ raise ValueError("master_server_address is required in extra_config")
80
+
81
+ return MooncakeStoreConfig(
82
+ local_hostname=extra_config.get("local_hostname", "localhost"),
83
+ metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
84
+ global_segment_size=extra_config.get(
85
+ "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
86
+ ),
87
+ local_buffer_size=extra_config.get(
88
+ "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
89
+ ),
90
+ protocol=extra_config.get("protocol", "tcp"),
91
+ device_name=extra_config.get("device_name", "auto"),
92
+ master_server_address=extra_config["master_server_address"],
93
+ )
94
+
78
95
  def __post_init__(self):
79
96
  if self.device_name == "auto":
80
97
  os.environ["MC_MS_AUTO_DISC"] = "1"
@@ -96,14 +113,39 @@ class MooncakeStore(HiCacheStorage):
96
113
 
97
114
  try:
98
115
  self.store = MooncakeDistributedStore()
99
- self.config = MooncakeStoreConfig.load_from_env()
100
- logger.info("Mooncake Configuration loaded from env successfully.")
116
+
117
+ extra_config = (
118
+ getattr(storage_config, "extra_config", None)
119
+ if storage_config
120
+ else None
121
+ )
122
+ # Load configuration with master_server_address prioritized from extra_config if available
123
+ if (
124
+ extra_config is not None
125
+ and extra_config.get("master_server_address") is not None
126
+ ):
127
+ # Load from extra_config
128
+ self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
129
+ logger.info(
130
+ "Mooncake Configuration loaded from extra_config successfully."
131
+ )
132
+ else:
133
+ # Load from environment variables
134
+ self.config = MooncakeStoreConfig.load_from_env()
135
+ logger.info("Mooncake Configuration loaded from env successfully.")
136
+
137
+ tp_scale_factor = 1 if storage_config is None else storage_config.tp_size
138
+
139
+ per_tp_global_segment_size = (
140
+ self.config.global_segment_size // tp_scale_factor
141
+ )
142
+ per_tp_local_buffer_size = self.config.local_buffer_size // tp_scale_factor
101
143
 
102
144
  ret_code = self.store.setup(
103
145
  self.config.local_hostname,
104
146
  self.config.metadata_server,
105
- self.config.global_segment_size,
106
- self.config.local_buffer_size,
147
+ per_tp_global_segment_size,
148
+ per_tp_local_buffer_size,
107
149
  self.config.protocol,
108
150
  self.config.device_name,
109
151
  self.config.master_server_address,
@@ -154,21 +196,36 @@ class MooncakeStore(HiCacheStorage):
154
196
  target_location: Optional[List[int]] = None,
155
197
  target_sizes: Optional[List[int]] = None,
156
198
  ) -> bool:
157
- return self.batch_set([key], [value], [target_location], [target_sizes])
199
+ # Only support zero copy set for now
200
+ assert target_location is not None and target_sizes is not None
201
+ exist_result = self._batch_exist([key])
202
+ if exist_result[0] == 1:
203
+ return True
204
+ put_result = self._put_batch_zero_copy_impl(
205
+ [key], [target_location], [target_sizes]
206
+ )
207
+ return put_result[0] == 0
158
208
 
159
209
  def batch_set(
160
210
  self,
161
211
  keys: List[str],
162
212
  values: Optional[List[torch.Tensor]] = None,
163
- target_location: Optional[List[int]] = None,
213
+ target_locations: Optional[List[int]] = None,
164
214
  target_sizes: Optional[List[int]] = None,
165
215
  ) -> bool:
166
- assert len(keys) == len(target_location) == len(target_sizes)
216
+ # Only support zero copy set for now
217
+ assert target_locations is not None and target_sizes is not None
218
+ assert len(keys) == len(target_locations) == len(target_sizes)
219
+
167
220
  if len(keys) == 0:
168
221
  return False
169
222
 
170
223
  for i in range(len(keys)):
171
- if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
224
+ if (
225
+ keys[i] is None
226
+ or target_locations[i] is None
227
+ or target_sizes[i] is None
228
+ ):
172
229
  return False
173
230
 
174
231
  exist_result = self._batch_exist(keys)
@@ -179,7 +236,7 @@ class MooncakeStore(HiCacheStorage):
179
236
  for i in range(len(keys)):
180
237
  if exist_result[i] != 1:
181
238
  set_keys.append(keys[i])
182
- set_target_locations.append(target_location[i])
239
+ set_target_locations.append(target_locations[i])
183
240
  set_target_sizes.append(target_sizes[i])
184
241
  set_indices.append(i)
185
242
  # Only set non-existing keys to storage
@@ -204,18 +261,24 @@ class MooncakeStore(HiCacheStorage):
204
261
  target_location: Optional[Any] = None,
205
262
  target_sizes: Optional[Any] = None,
206
263
  ) -> bool:
207
- return self.batch_get([key], [target_location], [target_sizes]) == 1
264
+ assert target_location is not None and target_sizes is not None
265
+ get_result = self._get_batch_zero_copy_impl(
266
+ [key], [target_location], [target_sizes]
267
+ )
268
+ return get_result[0] >= 0
208
269
 
209
270
  def batch_get(
210
271
  self,
211
272
  keys: List[str],
212
- target_location: Optional[Any] = None,
273
+ target_locations: Optional[Any] = None,
213
274
  target_sizes: Optional[Any] = None,
214
275
  ) -> int:
215
- assert len(keys) == len(target_location) == len(target_sizes)
276
+ assert len(keys) == len(target_locations) == len(target_sizes)
216
277
  if len(keys) == 0:
217
278
  return 0
218
- get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
279
+ get_result = self._get_batch_zero_copy_impl(
280
+ keys, target_locations, target_sizes
281
+ )
219
282
  if self.is_mla_backend:
220
283
  key_multiplier = 1
221
284
  else:
@@ -226,7 +289,8 @@ class MooncakeStore(HiCacheStorage):
226
289
  return len(keys) // key_multiplier
227
290
 
228
291
  def exists(self, key) -> bool:
229
- return self.batch_exists([key]) > 0
292
+ exist_result = self._batch_exist([key])
293
+ return exist_result[0] == 1
230
294
 
231
295
  def batch_exists(self, keys) -> int:
232
296
  if self.is_mla_backend:
@@ -245,9 +309,6 @@ class MooncakeStore(HiCacheStorage):
245
309
  return i // key_multiplier
246
310
  return len(query_keys) // key_multiplier
247
311
 
248
- def delete(self, key) -> None:
249
- raise (NotImplementedError)
250
-
251
312
  def close(self):
252
313
  # MooncakeDistributedStore will automatically call the destructor, so
253
314
  # it is unnecessary to close it manually.