sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -1,4 +1,3 @@
-import hashlib
 import json
 import logging
 import os
@@ -6,15 +5,18 @@ import uuid
 from dataclasses import dataclass
 from typing import Any, List, Optional

-import numpy as np
 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
-
+DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH"
 logger = logging.getLogger(__name__)


@@ -31,13 +33,13 @@ class MooncakeStoreConfig:
     @staticmethod
     def from_file() -> "MooncakeStoreConfig":
         """Load the config from a JSON file."""
-        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
-        if file_path is None:
-            raise ValueError(
-                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
-            )
-        with open(file_path) as fin:
-            config = json.load(fin)
+        file_path = os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV)
+        try:
+            with open(file_path) as fin:
+                config = json.load(fin)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load config from {file_path}: {str(e)}")
+
         return MooncakeStoreConfig(
             local_hostname=config.get("local_hostname"),
             metadata_server=config.get("metadata_server"),
@@ -75,6 +77,26 @@ class MooncakeStoreConfig:
             master_server_address=os.getenv("MOONCAKE_MASTER"),
         )

+    @staticmethod
+    def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
+        """Load config from extra_config dictionary."""
+        if "master_server_address" not in extra_config:
+            raise ValueError("master_server_address is required in extra_config")
+
+        return MooncakeStoreConfig(
+            local_hostname=extra_config.get("local_hostname", "localhost"),
+            metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
+            global_segment_size=extra_config.get(
+                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+            ),
+            local_buffer_size=extra_config.get(
+                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+            ),
+            protocol=extra_config.get("protocol", "tcp"),
+            device_name=extra_config.get("device_name", "auto"),
+            master_server_address=extra_config["master_server_address"],
+        )
+
     def __post_init__(self):
         if self.device_name == "auto":
             os.environ["MC_MS_AUTO_DISC"] = "1"
@@ -84,6 +106,7 @@ class MooncakeStoreConfig:


 class MooncakeStore(HiCacheStorage):
+
     def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
@@ -96,14 +119,43 @@ class MooncakeStore(HiCacheStorage):

         try:
             self.store = MooncakeDistributedStore()
-            self.config = MooncakeStoreConfig.load_from_env()
-            logger.info("Mooncake Configuration loaded from env successfully.")
+
+            extra_config = (
+                getattr(storage_config, "extra_config", None)
+                if storage_config
+                else None
+            )
+            # Load configuration with master_server_address prioritized from extra_config if available
+            if (
+                extra_config is not None
+                and extra_config.get("master_server_address") is not None
+            ):
+                # Load from extra_config
+                self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
+                logger.info(
+                    "Mooncake Configuration loaded from extra_config successfully."
+                )
+            elif os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV):
+                # Load from config file
+                self.config = MooncakeStoreConfig.from_file()
+                logger.info("Mooncake Configuration loaded from file successfully.")
+            else:
+                # Load from environment variables
+                self.config = MooncakeStoreConfig.load_from_env()
+                logger.info("Mooncake Configuration loaded from env successfully.")
+
+            tp_scale_factor = 1 if storage_config is None else storage_config.tp_size
+
+            per_tp_global_segment_size = (
+                self.config.global_segment_size // tp_scale_factor
+            )
+            per_tp_local_buffer_size = self.config.local_buffer_size // tp_scale_factor

             ret_code = self.store.setup(
                 self.config.local_hostname,
                 self.config.metadata_server,
-                self.config.global_segment_size,
-                self.config.local_buffer_size,
+                per_tp_global_segment_size,
+                per_tp_local_buffer_size,
                 self.config.protocol,
                 self.config.device_name,
                 self.config.master_server_address,
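
Two behavioral notes on the hunk above: configuration now resolves in priority order (an extra_config carrying master_server_address first, then a JSON file pointed to by SGLANG_HICACHE_MOONCAKE_CONFIG_PATH, then plain environment variables), and the sizes registered with the store are divided by the tensor-parallel degree. A small worked example of that division, assuming tp_size=4:

    # Per-TP sizing as computed in MooncakeStore.__init__ above; tp_size=4 is an assumed value.
    DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB (module default)
    DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB (module default)
    tp_size = 4
    per_tp_global_segment_size = DEFAULT_GLOBAL_SEGMENT_SIZE // tp_size  # 1 GiB per rank
    per_tp_local_buffer_size = DEFAULT_LOCAL_BUFFER_SIZE // tp_size  # 4 MB per rank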
@@ -136,7 +188,13 @@ class MooncakeStore(HiCacheStorage):
         assert self.store.is_exist(warmup_key) == 1
         assert self.store.get(warmup_key) == warmup_value

-    def register_buffer(self, buffer: torch.Tensor) -> None:
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        assert self.mem_pool_host.layout in [
+            "page_first",
+            "page_first_direct",
+        ], "mooncake store storage backend only support page first or page first direct layout"
+        buffer = self.mem_pool_host.kv_buffer
         try:
             buffer_ptr = buffer.data_ptr()
             buffer_size = buffer.numel() * buffer.element_size()
@@ -147,6 +205,97 @@ class MooncakeStore(HiCacheStorage):
             logger.error("Failed to register buffer to Mooncake Store: %s", err)
             raise TypeError("Mooncake Store Register Buffer Error.") from err

+    def _get_mha_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_{self.local_rank}_k")
+            key_list.append(f"{key_}_{self.local_rank}_v")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _get_mla_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_k")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _batch_preprocess(self, keys, host_indices):
+        assert len(keys) > 0
+        assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
+        if self.is_mla_backend:
+            return self._get_mla_buffer_meta(keys, host_indices)
+        else:
+            return self._get_mha_buffer_meta(keys, host_indices)
+
+    def _batch_postprocess(self, results: List[int], is_set_operate=False):
+        """
+        refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
+        for batch_get_into, results is Vector of integers,
+        where each element is the number of bytes read on success, or a negative value on error
+        for batch_put_from, results is Vector of integers,
+        where each element is 0 on success, or a negative value on error
+        """
+        if self.is_mla_backend:
+            return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
+        else:
+            kv_pairs = zip(results[::2], results[1::2])
+            return [
+                (
+                    (k_res == 0 and v_res == 0)
+                    if is_set_operate
+                    else (k_res > 0 and v_res > 0)
+                )
+                for k_res, v_res in kv_pairs
+            ]
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        get_results = self._get_batch_zero_copy_impl(
+            key_strs, buffer_ptrs, buffer_sizes
+        )
+        return self._batch_postprocess(get_results, is_set_operate=False)
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        exist_result = self._batch_exist(key_strs)
+
+        set_keys = []
+        set_buffer_ptrs = []
+        set_buffer_sizes = []
+        set_indices = []
+        set_results = [-1] * len(key_strs)
+        for i in range(len(key_strs)):
+            if exist_result[i] != 1:
+                set_keys.append(key_strs[i])
+                set_buffer_ptrs.append(buffer_ptrs[i])
+                set_buffer_sizes.append(buffer_sizes[i])
+                set_indices.append(i)
+            else:
+                set_results[i] = 0
+
+        # Only set non-existing keys to storage
+        if len(set_keys) > 0:
+            put_results = self._put_batch_zero_copy_impl(
+                set_keys, set_buffer_ptrs, set_buffer_sizes
+            )
+            for i in range(len(set_indices)):
+                set_results[set_indices[i]] = put_results[i]
+
+        return self._batch_postprocess(set_results, is_set_operate=True)
+
     def set(
         self,
         key,
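
The v1 batch APIs above expand each logical page key into per-rank store keys ("{key}_{rank}_k" / "{key}_{rank}_v" for MHA backends, a single "{key}_k" for MLA) and then fold the raw Mooncake return codes back into one boolean per page. A self-contained sketch of that folding, mirroring _batch_postprocess with invented result values:

    # Mirrors the _batch_postprocess logic above; the result lists below are made-up examples.
    def fold_results(results, is_mla, is_set_operate):
        if is_mla:  # one "<key>_k" entry per page
            return [r == 0 if is_set_operate else r > 0 for r in results]
        # MHA: results alternate k/v entries for each page
        pairs = zip(results[::2], results[1::2])
        return [
            (k == 0 and v == 0) if is_set_operate else (k > 0 and v > 0)
            for k, v in pairs
        ]

    # batch_get_into reports bytes read (positive on success) ...
    assert fold_results([4096, 4096, -1, 4096], is_mla=False, is_set_operate=False) == [True, False]
    # ... while batch_put_from reports 0 on success.
    assert fold_results([0, 0, 0, -5], is_mla=False, is_set_operate=True) == [True, False]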
@@ -154,21 +303,36 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-        return self.batch_set([key], [value], [target_location], [target_sizes])
+        # Only support zero copy set for now
+        assert target_location is not None and target_sizes is not None
+        exist_result = self._batch_exist([key])
+        if exist_result[0] == 1:
+            return True
+        put_result = self._put_batch_zero_copy_impl(
+            [key], [target_location], [target_sizes]
+        )
+        return put_result[0] == 0

     def batch_set(
         self,
         keys: List[str],
         values: Optional[List[torch.Tensor]] = None,
-        target_location: Optional[List[int]] = None,
+        target_locations: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-        assert len(keys) == len(target_location) == len(target_sizes)
+        # Only support zero copy set for now
+        assert target_locations is not None and target_sizes is not None
+        assert len(keys) == len(target_locations) == len(target_sizes)
+
         if len(keys) == 0:
             return False

         for i in range(len(keys)):
-            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
+            if (
+                keys[i] is None
+                or target_locations[i] is None
+                or target_sizes[i] is None
+            ):
                 return False

         exist_result = self._batch_exist(keys)
@@ -179,7 +343,7 @@ class MooncakeStore(HiCacheStorage):
         for i in range(len(keys)):
             if exist_result[i] != 1:
                 set_keys.append(keys[i])
-                set_target_locations.append(target_location[i])
+                set_target_locations.append(target_locations[i])
                 set_target_sizes.append(target_sizes[i])
                 set_indices.append(i)
         # Only set non-existing keys to storage
@@ -204,18 +368,24 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
-        return self.batch_get([key], [target_location], [target_sizes]) == 1
+        assert target_location is not None and target_sizes is not None
+        get_result = self._get_batch_zero_copy_impl(
+            [key], [target_location], [target_sizes]
+        )
+        return get_result[0] >= 0

     def batch_get(
         self,
         keys: List[str],
-        target_location: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> int:
-        assert len(keys) == len(target_location) == len(target_sizes)
+        assert len(keys) == len(target_locations) == len(target_sizes)
         if len(keys) == 0:
             return 0
-        get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+        get_result = self._get_batch_zero_copy_impl(
+            keys, target_locations, target_sizes
+        )
         if self.is_mla_backend:
             key_multiplier = 1
         else:
@@ -226,7 +396,8 @@ class MooncakeStore(HiCacheStorage):
         return len(keys) // key_multiplier

     def exists(self, key) -> bool:
-        return self.batch_exists([key]) > 0
+        exist_result = self._batch_exist([key])
+        return exist_result[0] == 1

     def batch_exists(self, keys) -> int:
         if self.is_mla_backend:
@@ -245,9 +416,6 @@ class MooncakeStore(HiCacheStorage):
                 return i // key_multiplier
         return len(query_keys) // key_multiplier

-    def delete(self, key) -> None:
-        raise (NotImplementedError)
-
     def close(self):
         # MooncakeDistributedStore will automatically call the destructor, so
         # it is unnecessary to close it manually.
sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py (new file)
@@ -0,0 +1,161 @@
+import logging
+import uuid
+
+import torch
+from mooncake_store import MooncakeStore
+
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def generate_batch_query_keys(kv_num: int, config: HiCacheStorageConfig):
+    keys = []
+    for _ in range(kv_num):
+        key = "test_" + str(uuid.uuid4())
+        keys.append(key)
+    set_keys = []
+    for key in keys:
+        if config.is_mla_model:
+            set_keys.append(key + "_k")
+        else:
+            set_keys.append(key + f"_{config.tp_rank}_k")
+            set_keys.append(key + f"_{config.tp_rank}_v")
+    get_keys = set_keys
+    exist_keys = keys
+    return set_keys, get_keys, exist_keys
+
+
+def test_single_operation():
+    """Test the set API with a single key-value pair."""
+    print("=" * 100)
+    print("Testing single operation")
+
+    buffer_size = 1024 * 1024 * 16  # 16MB
+    value_elements = 1024
+    store = MooncakeStore()
+    buffer = torch.randn(buffer_size, dtype=torch.float32)
+    store.register_buffer(buffer)
+    value_size = value_elements * buffer.element_size()
+
+    key = str(uuid.uuid4())
+    set_slice = buffer[:value_elements]
+    get_slice = buffer[value_elements : 2 * value_elements]
+    set_location = set_slice.data_ptr()
+    get_location = get_slice.data_ptr()
+
+    # Test set operation
+    result = store.set(key, target_location=set_location, target_sizes=value_size)
+    assert result is True, f"❌set operation failed for key: {key}"
+
+    # Test exists operation
+    assert store.exists(key), f"❌key {key} should exist after set operation"
+
+    # Test get operation
+    result = store.get(key, target_location=get_location, target_sizes=value_size)
+    assert result is True, f"❌get operation failed for key: {key}"
+
+    # Compare the data using proper tensor indices
+    assert torch.allclose(
+        set_slice, get_slice, atol=1e-6
+    ), f"❌get operation failed for key: {key}"
+
+    logger.info(f"✅ Single operation passed")
+
+
+def test_batch_operation(config: HiCacheStorageConfig):
+    """Test the batch set/get APIs with multiple key-value pairs."""
+    print("=" * 100)
+    print(f"Testing batch operation with config: {config}")
+
+    buffer_size = 1024 * 1024 * 16  # 16MB
+    value_elements = 256
+    kv_num = 13
+    store = MooncakeStore(config)
+    buffer = torch.randn(buffer_size, dtype=torch.float32)
+    store.register_buffer(buffer)
+    value_size = value_elements * buffer.element_size()
+
+    set_keys, get_keys, exist_keys = generate_batch_query_keys(kv_num, config)
+    set_slices = [
+        buffer[i * value_elements : (i + 1) * value_elements]
+        for i in range(len(set_keys))
+    ]
+    set_locations = [set_slice.data_ptr() for set_slice in set_slices]
+    target_sizes = [value_size for _ in range(len(set_keys))]
+
+    # Test batch set operation
+    result = store.batch_set(
+        set_keys, target_locations=set_locations, target_sizes=target_sizes
+    )
+    assert result is True, f"❌batch set operation failed"
+
+    # Test batch exists operation
+    assert store.batch_exists(
+        exist_keys
+    ), f"❌keys should exist after batch set operation"
+
+    # Test batch get operation
+    get_slices = [
+        buffer[
+            (len(set_keys) + i)
+            * value_elements : (len(set_keys) + i + 1)
+            * value_elements
+        ]
+        for i in range(len(get_keys))
+    ]
+    get_locations = [get_slice.data_ptr() for get_slice in get_slices]
+    result = store.batch_get(
+        get_keys, target_locations=get_locations, target_sizes=target_sizes
+    )
+    assert result == kv_num, f"❌batch get operation failed"
+    for i in range(len(get_keys)):
+        assert torch.allclose(
+            set_slices[i], get_slices[i], atol=1e-6
+        ), f"❌batch get operation failed for key: {get_keys[i]}"
+
+    logger.info(f"✅ Batch operation passed")
+
+
+if __name__ == "__main__":
+    test_single_operation()
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=False,
+            tp_rank=0,
+            tp_size=1,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=True,
+            tp_rank=0,
+            tp_size=1,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=False,
+            tp_rank=1,
+            tp_size=4,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=True,
+            tp_rank=3,
+            tp_size=8,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    logger.info(f"✅ All tests passed")
sglang/srt/mem_cache/swa_radix_cache.py
@@ -30,6 +30,12 @@ import torch
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import (
+    RadixKey,
+    _key_match_page_size1,
+    _key_match_paged,
+    get_child_key,
+)

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -47,7 +53,7 @@ class TreeNode:
     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
         self.parent: TreeNode = None
-        self.key: List[int] = None
+        self.key: RadixKey = None
         self.value: Optional[torch.Tensor] = None
         # swa_tombstone is used to indicate the kv indices have been freed for swa layers
         self.swa_tombstone = False
@@ -60,8 +66,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # store the host indices of KV cache
         self.host_value = None

@@ -89,27 +93,6 @@ class TreeNode:
         return self.last_access_time < other.last_access_time


-def _key_match_page_size1(key0: List, key1: List):
-    i = 0
-    for k0, k1 in zip(key0, key1):
-        if k0 != k1:
-            break
-        i += 1
-    return i
-
-
-def _key_match_paged(key0: List, key1: List, page_size: int):
-    min_len = min(len(key0), len(key1))
-
-    i = 0
-    while i < min_len:
-        if key0[i : i + page_size] != key1[i : i + page_size]:
-            break
-        i += page_size
-
-    return i
-
-
 def gen_swa_uuid() -> int:
     TreeNode.swa_uuid_counter += 1
     return TreeNode.swa_uuid_counter
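
The two helpers removed above are not gone: the import hunk at the top of this file's diff now pulls _key_match_page_size1, _key_match_paged, and get_child_key from sglang.srt.mem_cache.radix_cache, so the SWA cache shares one implementation with the regular radix cache. Their behavior is a page-aligned longest-common-prefix length; a copy with a small worked example (page_size=2 is an assumed value):

    # Copy of the relocated helper, shown here only to illustrate its contract.
    def _key_match_paged(key0, key1, page_size):
        min_len = min(len(key0), len(key1))
        i = 0
        while i < min_len:
            if key0[i : i + page_size] != key1[i : i + page_size]:
                break
            i += page_size
        return i

    # The third token differs, so only the first full page (2 tokens) matches.
    assert _key_match_paged([1, 2, 3, 4, 5, 6], [1, 2, 9, 4, 5, 6], page_size=2) == 2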
@@ -358,10 +341,10 @@ class SWARadixCache(BasePrefixCache):

         if self.page_size == 1:
             self.key_match_fn = _key_match_page_size1
-            self.get_child_key_fn = lambda key: key[0]
+            self.get_child_key_fn = get_child_key
         else:
             self.key_match_fn = partial(_key_match_paged, page_size=page_size)
-            self.get_child_key_fn = lambda key: tuple(key[:page_size])
+            self.get_child_key_fn = partial(get_child_key, page_size=page_size)

         self.sliding_window_size = sliding_window_size
         self.reset()
@@ -382,10 +365,10 @@ class SWARadixCache(BasePrefixCache):
         self.full_lru_list = LRUList(swa=False)
         self.swa_lru_list = LRUList(swa=True)

-    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+    def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
-            key: A list of token IDs to find a matching prefix.
+            key: A RadixKey contains token IDs to find a matching prefix.
         Returns:
             A tuple of a tensor of matching prefix token IDs and
             the last node that contains the prefix values. Note that
@@ -419,12 +402,12 @@ class SWARadixCache(BasePrefixCache):
             last_host_node=last_node,
         )

-    def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int:
+    def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
         if self.disable:
             return 0

         if value is None:
-            value = [x for x in key]
+            value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
         return self._insert_helper(self.root_node, key, value, prev_prefix_len)

     def cache_finished_req(self, req: Req) -> None:
@@ -455,7 +438,7 @@ class SWARadixCache(BasePrefixCache):
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
-            token_ids[:page_aligned_len],
+            RadixKey(token_ids[:page_aligned_len], req.extra_key),
             page_aligned_kv_indices,
             len(req.prefix_indices),
         )
@@ -491,11 +474,15 @@ class SWARadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
-            page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices)
+            RadixKey(page_aligned_token_ids, req.extra_key),
+            page_aligned_kv_indices,
+            len(req.prefix_indices),
         )

         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(
+            RadixKey(page_aligned_token_ids, req.extra_key)
+        )
         assert len(req.prefix_indices) <= len(
             new_indices
         ), f"{req.prefix_indices=}, {new_indices=}"
@@ -734,7 +721,9 @@ class SWARadixCache(BasePrefixCache):

     ##### Internal Helper Functions #####

-    def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]:
+    def _match_prefix_helper(
+        self, key: RadixKey
+    ) -> Tuple[List[torch.Tensor], TreeNode]:
         """
         SWA prefix matching helper. It factors in the sliding window size such that
         the matched node is guaranteed to either 1. connected to root without swa tombstone,
@@ -798,7 +787,7 @@ class SWARadixCache(BasePrefixCache):

         return value[:best_value_len], best_last_node

-    def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode:
+    def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {self.get_child_key_fn(key[split_len:]): child}
@@ -833,7 +822,7 @@ class SWARadixCache(BasePrefixCache):
         return new_node

     def _insert_helper(
-        self, node: TreeNode, key: List, value, update_kv_after_len: int
+        self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int
     ) -> int:
         # Update the last access time from root to leaf, so that
         # swa will tombstone the node closer to root first