sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,6 @@
1
1
  import abc
2
2
  import logging
3
3
  import threading
4
- from enum import IntEnum
5
4
  from functools import wraps
6
5
  from typing import Optional
7
6
 
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
  from collections import OrderedDict
3
- from typing import Dict
4
3
 
5
4
  import torch
6
5
 
@@ -22,8 +22,8 @@ The radix tree data structure for managing the KV cache.
22
22
  import heapq
23
23
  import time
24
24
  from collections import defaultdict
25
- from functools import partial
26
- from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
25
+ from functools import lru_cache, partial
26
+ from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
27
27
 
28
28
  import torch
29
29
 
@@ -34,7 +34,14 @@ from sglang.srt.disaggregation.kv_events import (
34
34
  )
35
35
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
36
36
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
37
- from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
37
+ from sglang.srt.mem_cache.evict_policy import (
38
+ EvictionStrategy,
39
+ FIFOStrategy,
40
+ FILOStrategy,
41
+ LFUStrategy,
42
+ LRUStrategy,
43
+ MRUStrategy,
44
+ )
38
45
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
39
46
 
40
47
  if TYPE_CHECKING:
@@ -76,6 +83,7 @@ class TreeNode:
76
83
  self.value: Optional[torch.Tensor] = None
77
84
  self.lock_ref = 0
78
85
  self.last_access_time = time.monotonic()
86
+ self.creation_time = time.monotonic()
79
87
 
80
88
  self.hit_count = 0
81
89
  # indicating the node is locked to protect from eviction
@@ -114,6 +122,13 @@ class TreeNode:
114
122
  return None
115
123
  return self.hash_value[-1]
116
124
 
125
+ @lru_cache(maxsize=1)
126
+ def get_prefix_hash_values(self, node: TreeNode) -> List[str]:
127
+ if node is None or node.hash_value is None:
128
+ return []
129
+
130
+ return node.get_prefix_hash_values(node.parent) + node.hash_value
131
+
117
132
  def __lt__(self, other: "TreeNode"):
118
133
  return self.last_access_time < other.last_access_time
119
134
 
@@ -209,9 +224,15 @@ class RadixCache(BasePrefixCache):
209
224
  self.eviction_strategy: EvictionStrategy = LRUStrategy()
210
225
  elif eviction_policy.lower() == "lfu":
211
226
  self.eviction_strategy: EvictionStrategy = LFUStrategy()
227
+ elif eviction_policy.lower() == "fifo":
228
+ self.eviction_strategy: EvictionStrategy = FIFOStrategy()
229
+ elif eviction_policy.lower() == "mru":
230
+ self.eviction_strategy: EvictionStrategy = MRUStrategy()
231
+ elif eviction_policy.lower() == "filo":
232
+ self.eviction_strategy: EvictionStrategy = FILOStrategy()
212
233
  else:
213
234
  raise ValueError(
214
- f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
235
+ f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu', 'fifo', 'mru', 'filo'."
215
236
  )
216
237
  self.reset()
217
238
 
@@ -314,18 +335,20 @@ class RadixCache(BasePrefixCache):
314
335
 
315
336
  return self._insert_helper(self.root_node, key, value)
316
337
 
317
- def cache_finished_req(self, req: Req):
338
+ def cache_finished_req(self, req: Req, is_insert: bool = True):
318
339
  """Cache request when it finishes."""
340
+ all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
319
341
  if self.disable:
320
342
  kv_indices = self.req_to_token_pool.req_to_token[
321
- req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
343
+ req.req_pool_idx, :all_token_len
322
344
  ]
323
345
  self.token_to_kv_pool_allocator.free(kv_indices)
324
346
  self.req_to_token_pool.free(req.req_pool_idx)
325
347
  return
326
348
 
327
- token_ids = (req.origin_input_ids + req.output_ids)[:-1]
328
- all_token_len = len(token_ids)
349
+ token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
350
+ # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
351
+ # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
329
352
  actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
330
353
  kv_indices = self.req_to_token_pool.req_to_token[
331
354
  req.req_pool_idx, :all_token_len
@@ -336,12 +359,9 @@ class RadixCache(BasePrefixCache):
336
359
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
337
360
  dtype=torch.int64, copy=True
338
361
  )
339
- self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
340
362
  else:
341
363
  page_aligned_len = actual_kv_len
342
364
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
343
- if self.is_eagle:
344
- self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
345
365
 
346
366
  page_aligned_token_len = (
347
367
  page_aligned_len + 1 if self.is_eagle else page_aligned_len
@@ -349,15 +369,27 @@ class RadixCache(BasePrefixCache):
349
369
 
350
370
  old_prefix_len = len(req.prefix_indices)
351
371
  if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
352
- # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
372
+ # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
373
+ # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
353
374
  old_prefix_len -= 1
354
375
 
355
376
  # Radix Cache takes one ref in memory pool
356
- new_prefix_len = self.insert(
357
- RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
358
- page_aligned_kv_indices,
359
- )
360
- self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
377
+ if is_insert:
378
+ new_prefix_len = self.insert(
379
+ RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
380
+ page_aligned_kv_indices,
381
+ )
382
+ # Free the duplicates that were already in the tree
383
+ self.token_to_kv_pool_allocator.free(
384
+ kv_indices[old_prefix_len:new_prefix_len]
385
+ )
386
+ else:
387
+ self.token_to_kv_pool_allocator.free(
388
+ kv_indices[old_prefix_len:page_aligned_len]
389
+ )
390
+
391
+ # free the unaligned tail
392
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
361
393
 
362
394
  # Remove req slot release the cache lock
363
395
  self.req_to_token_pool.free(req.req_pool_idx)
@@ -370,7 +402,8 @@ class RadixCache(BasePrefixCache):
370
402
 
371
403
  token_ids = req.fill_ids
372
404
  all_token_len = len(token_ids)
373
- # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
405
+ # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
406
+ # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
374
407
  actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
375
408
  kv_indices = self.req_to_token_pool.req_to_token[
376
409
  req.req_pool_idx, :all_token_len
@@ -393,7 +426,8 @@ class RadixCache(BasePrefixCache):
393
426
 
394
427
  old_prefix_len = len(req.prefix_indices)
395
428
  if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
396
- # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
429
+ # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
430
+ # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
397
431
  old_prefix_len -= 1
398
432
 
399
433
  # Radix Cache takes one ref in memory pool
@@ -151,32 +151,37 @@ class RadixCacheCpp(BasePrefixCache):
151
151
  def total_size(self):
152
152
  return self.tree.total_size()
153
153
 
154
- def cache_finished_req(self, req: Req):
154
+ def cache_finished_req(self, req: Req, is_insert: bool = True):
155
155
  """Cache request when it finishes."""
156
156
  assert req.req_pool_idx is not None
157
- token_ids = (req.origin_input_ids + req.output_ids)[:-1]
157
+ all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
158
+ token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
158
159
  overall_len = len(token_ids) # prefill + decode
159
160
  kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
160
161
 
161
162
  # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
162
163
  # it will automatically align them, but length of them should be equal
163
164
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
164
- new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
165
+ page_aligned_overall_len = overall_len // self.page_size * self.page_size
165
166
 
166
- # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
167
- assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
168
-
169
- # KVCache between old & new is newly generated, but already exists in the pool
170
- # we need to free this newly generated kv indices
171
- if old_prefix_len < new_prefix_len:
172
- self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
167
+ if is_insert:
168
+ new_prefix_len = self._insert(
169
+ RadixKey(token_ids, req.extra_key), kv_indices
170
+ )
171
+ # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
172
+ assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
173
+ # Free duplicates that were already in the pool
174
+ if old_prefix_len < new_prefix_len:
175
+ self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
176
+ else:
177
+ self.token_to_kv_pool.free(
178
+ kv_indices[old_prefix_len:page_aligned_overall_len]
179
+ )
173
180
 
174
181
  # need to free the unaligned part, since it cannot be inserted into the radix tree
175
- if self.page_size != 1 and ( # unaligned tail only exists when page_size > 1
176
- (unaligned_len := overall_len % self.page_size) > 0
177
- ):
182
+ if page_aligned_overall_len < overall_len:
178
183
  # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
179
- self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
184
+ self.token_to_kv_pool.free(kv_indices[page_aligned_overall_len:])
180
185
 
181
186
  # Remove req slot release the cache lock
182
187
  self.dec_lock_ref(req.last_node)
@@ -13,7 +13,11 @@ from aibrix_kvcache import (
13
13
  )
14
14
  from aibrix_kvcache.common.absl_logging import log_every_n_seconds
15
15
 
16
- from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
16
+ from sglang.srt.mem_cache.hicache_storage import (
17
+ HiCacheStorage,
18
+ HiCacheStorageConfig,
19
+ HiCacheStorageExtraInfo,
20
+ )
17
21
  from sglang.srt.mem_cache.memory_pool_host import HostKVCache
18
22
 
19
23
  logger = logging.getLogger(__name__)
@@ -140,7 +144,9 @@ class AibrixKVCacheStorage(HiCacheStorage):
140
144
  ) -> bool:
141
145
  return self.batch_set([key], [value], [target_location], [target_size])
142
146
 
143
- def batch_exists(self, keys: List[str]) -> int:
147
+ def batch_exists(
148
+ self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
149
+ ) -> int:
144
150
  block_hash = BlockHashes(keys, self.page_size)
145
151
  status = self.kv_cache_manager.exists(None, block_hash)
146
152
  if status.is_ok():
@@ -3,20 +3,8 @@ import os
3
3
 
4
4
  import torch
5
5
  import torch.distributed
6
- from aibrix_kvcache import (
7
- BaseKVCacheManager,
8
- GroupAwareKVCacheManager,
9
- KVCacheBlockLayout,
10
- KVCacheBlockSpec,
11
- KVCacheConfig,
12
- KVCacheMetrics,
13
- KVCacheTensorSpec,
14
- ModelSpec,
15
- TokenListView,
16
- )
17
- from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
6
+ from aibrix_kvcache.common.absl_logging import log_every_n_seconds
18
7
  from aibrix_kvcache_storage import AibrixKVCacheStorage
19
- from torch.distributed import Backend, ProcessGroup
20
8
 
21
9
  from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
22
10
  from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
@@ -161,7 +161,7 @@ class StorageBackendFactory:
161
161
  if backend_name == "file":
162
162
  return backend_class(storage_config)
163
163
  elif backend_name == "nixl":
164
- return backend_class()
164
+ return backend_class(storage_config)
165
165
  elif backend_name == "mooncake":
166
166
  backend = backend_class(storage_config)
167
167
  return backend
@@ -170,7 +170,7 @@ class StorageBackendFactory:
170
170
  return backend
171
171
  elif backend_name == "hf3fs":
172
172
  # Calculate bytes_per_page based on memory pool layout
173
- if mem_pool_host.layout == "page_first":
173
+ if mem_pool_host.layout in ["page_first", "page_first_direct"]:
174
174
  bytes_per_page = (
175
175
  mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
176
176
  )
@@ -2,21 +2,18 @@ import json
2
2
  import logging
3
3
  import os
4
4
  import time
5
- import uuid
6
- from dataclasses import dataclass
7
- from typing import Any, Dict, List, Optional, Tuple
5
+ from typing import Any, List, Optional, Tuple
8
6
 
9
7
  import eic
10
8
  import torch
11
9
  import yaml
12
10
 
13
- from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
14
11
  from sglang.srt.mem_cache.hicache_storage import (
15
12
  HiCacheStorage,
16
13
  HiCacheStorageConfig,
17
14
  HiCacheStorageExtraInfo,
18
15
  )
19
- from sglang.srt.mem_cache.memory_pool_host import HostKVCache, MLATokenToKVPoolHost
16
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
20
17
 
21
18
  logger = logging.getLogger(__name__)
22
19
 
@@ -408,7 +405,9 @@ class EICStorage(HiCacheStorage):
408
405
  exist_num = self.batch_exists([key])
409
406
  return exist_num == 1
410
407
 
411
- def batch_exists(self, keys) -> int:
408
+ def batch_exists(
409
+ self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None
410
+ ) -> int:
412
411
  if len(keys) == 0:
413
412
  return 0
414
413
  if self.use_zero_copy and not self.is_mla_model:
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
  import os
3
- import threading
4
3
  from abc import ABC, abstractmethod
5
4
  from typing import List
6
5
 
@@ -3,8 +3,9 @@ import atexit
3
3
  import json
4
4
  import logging
5
5
  import threading
6
+ from collections import OrderedDict
6
7
  from pathlib import Path
7
- from typing import Dict, List, Optional, OrderedDict, Tuple
8
+ from typing import Dict, List, Optional, Tuple
8
9
 
9
10
  import orjson
10
11
  import requests
@@ -136,7 +137,7 @@ class GlobalMetadataState:
136
137
  num_pages = data["num_pages"]
137
138
  rank_meta = RankMetadata(num_pages)
138
139
  rank_meta.free_pages = data["free_pages"]
139
- rank_meta.key_to_index = dict(data["key_to_index"])
140
+ rank_meta.key_to_index = OrderedDict(data["key_to_index"])
140
141
  self.ranks[rank_id] = rank_meta
141
142
  logging.info(
142
143
  f"Successfully loaded metadata for {len(self.ranks)} ranks."
@@ -454,7 +454,9 @@ class HiCacheHF3FS(HiCacheStorage):
454
454
  result = self.metadata_client.exists(self.rank, [key])
455
455
  return result[0] if result else False
456
456
 
457
- def batch_exists(self, keys: List[str]) -> int:
457
+ def batch_exists(
458
+ self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
459
+ ) -> int:
458
460
  factor = 1
459
461
  if self.is_zero_copy and not self.is_mla_model:
460
462
  keys = self._get_mha_zero_copy_keys(keys)
@@ -499,8 +501,12 @@ class HiCacheHF3FS(HiCacheStorage):
499
501
 
500
502
  def register_mem_pool_host(self, mem_pool_host: HostKVCache):
501
503
  super().register_mem_pool_host(mem_pool_host)
502
- self.is_zero_copy = self.mem_pool_host.layout == "page_first"
503
- logger.info(f"{self.is_zero_copy=}")
504
+ self.is_zero_copy = self.mem_pool_host.layout in [
505
+ "page_first",
506
+ "page_first_direct",
507
+ ]
508
+
509
+ logger.info(f"{self.is_zero_copy=}, layout={self.mem_pool_host.layout}")
504
510
 
505
511
  def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
506
512
  _keys = []
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import threading
5
- from typing import TYPE_CHECKING, List, Optional
5
+ from typing import TYPE_CHECKING, Optional
6
6
 
7
7
  import torch
8
8
 
@@ -217,10 +217,12 @@ class LMCRadixCache(RadixCache):
217
217
 
218
218
  return base_res
219
219
 
220
- def cache_finished_req(self, req: "Req") -> None: # type: ignore[override]
220
+ def cache_finished_req(self, req: "Req", is_insert: bool = True) -> None: # type: ignore[override]
221
221
  """On request completion, insert device KV into radix and store to LMCache."""
222
222
 
223
- super().cache_finished_req(req)
223
+ super().cache_finished_req(req, is_insert=is_insert)
224
+ if not is_insert:
225
+ return
224
226
 
225
227
  token_ids = (req.origin_input_ids + req.output_ids)[:-1]
226
228
  kv_indices = self.req_to_token_pool.req_to_token[
@@ -1,10 +1,12 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
+ import time
4
5
  import uuid
5
6
  from dataclasses import dataclass
6
7
  from typing import Any, List, Optional
7
8
 
9
+ import requests
8
10
  import torch
9
11
 
10
12
  from sglang.srt.mem_cache.hicache_storage import (
@@ -17,9 +19,29 @@ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
17
19
  DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
18
20
  DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
19
21
  DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH"
22
+ SETUP_TIMEOUT = 600 # 10min
23
+ DEFAULT_MASTER_METRICS_PORT = 9003
24
+ DEFAULT_CHECK_SERVER = False
25
+
20
26
  logger = logging.getLogger(__name__)
21
27
 
22
28
 
29
+ def _parse_global_segment_size(value) -> int:
30
+ if isinstance(value, int):
31
+ return value
32
+ if isinstance(value, str):
33
+ s = value.strip().lower()
34
+ if s.endswith("gb"):
35
+ num = s[:-2].strip()
36
+ if not num:
37
+ raise ValueError(
38
+ "Invalid global_segment_size: missing number before 'gb'"
39
+ )
40
+ return int(num) * 1024 * 1024 * 1024
41
+ return int(s)
42
+ return int(value)
43
+
44
+
23
45
  @dataclass
24
46
  class MooncakeStoreConfig:
25
47
  local_hostname: str
@@ -29,6 +51,8 @@ class MooncakeStoreConfig:
29
51
  protocol: str
30
52
  device_name: str
31
53
  master_server_address: str
54
+ master_metrics_port: int
55
+ check_server: bool
32
56
 
33
57
  @staticmethod
34
58
  def from_file() -> "MooncakeStoreConfig":
@@ -43,14 +67,18 @@ class MooncakeStoreConfig:
43
67
  return MooncakeStoreConfig(
44
68
  local_hostname=config.get("local_hostname"),
45
69
  metadata_server=config.get("metadata_server"),
46
- global_segment_size=config.get(
47
- "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
70
+ global_segment_size=_parse_global_segment_size(
71
+ config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
48
72
  ),
49
73
  # Zero copy interface does not need local buffer
50
74
  local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
51
75
  protocol=config.get("protocol", "tcp"),
52
- device_name=config.get("device_name", "auto"),
76
+ device_name=config.get("device_name", ""),
53
77
  master_server_address=config.get("master_server_address"),
78
+ master_metrics_port=config.get(
79
+ "master_metrics_port", DEFAULT_MASTER_METRICS_PORT
80
+ ),
81
+ check_server=config.get("check_server", DEFAULT_CHECK_SERVER),
54
82
  )
55
83
 
56
84
  @staticmethod
@@ -58,7 +86,7 @@ class MooncakeStoreConfig:
58
86
  """Load config from a file specified in the environment variable.
59
87
  export MOONCAKE_MASTER=10.13.3.232:50051
60
88
  export MOONCAKE_PROTOCOL="rdma"
61
- export MOONCAKE_DEVICE="auto"
89
+ export MOONCAKE_DEVICE=""
62
90
  export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"
63
91
  """
64
92
  # other required environment variables...
@@ -67,14 +95,18 @@ class MooncakeStoreConfig:
67
95
  return MooncakeStoreConfig(
68
96
  local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"),
69
97
  metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"),
70
- global_segment_size=int(
98
+ global_segment_size=_parse_global_segment_size(
71
99
  os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
72
100
  ),
73
101
  # Zero copy interface does not need local buffer
74
102
  local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
75
103
  protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
76
- device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
104
+ device_name=os.getenv("MOONCAKE_DEVICE", ""),
77
105
  master_server_address=os.getenv("MOONCAKE_MASTER"),
106
+ master_metrics_port=int(
107
+ os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_GLOBAL_SEGMENT_SIZE)
108
+ ),
109
+ check_server=bool(os.getenv("MOONCAKE_CHECK_SERVER", DEFAULT_CHECK_SERVER)),
78
110
  )
79
111
 
80
112
  @staticmethod
@@ -86,24 +118,21 @@ class MooncakeStoreConfig:
86
118
  return MooncakeStoreConfig(
87
119
  local_hostname=extra_config.get("local_hostname", "localhost"),
88
120
  metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
89
- global_segment_size=extra_config.get(
90
- "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
121
+ global_segment_size=_parse_global_segment_size(
122
+ extra_config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
91
123
  ),
92
124
  local_buffer_size=extra_config.get(
93
125
  "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
94
126
  ),
95
127
  protocol=extra_config.get("protocol", "tcp"),
96
- device_name=extra_config.get("device_name", "auto"),
128
+ device_name=extra_config.get("device_name", ""),
97
129
  master_server_address=extra_config["master_server_address"],
130
+ master_metrics_port=extra_config.get(
131
+ "master_metrics_port", DEFAULT_MASTER_METRICS_PORT
132
+ ),
133
+ check_server=extra_config.get("check_server", DEFAULT_CHECK_SERVER),
98
134
  )
99
135
 
100
- def __post_init__(self):
101
- if self.device_name == "auto":
102
- os.environ["MC_MS_AUTO_DISC"] = "1"
103
- os.environ["MC_MS_FILTERS"] = (
104
- "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3"
105
- )
106
-
107
136
 
108
137
  class MooncakeStore(HiCacheStorage):
109
138
 
@@ -151,6 +180,16 @@ class MooncakeStore(HiCacheStorage):
151
180
  )
152
181
  per_tp_local_buffer_size = self.config.local_buffer_size // tp_scale_factor
153
182
 
183
+ # Check if extra_backend_tag should be passed to MooncakeDistributedStore
184
+ self.extra_backend_tag = None
185
+ if extra_config and "extra_backend_tag" in extra_config:
186
+ self.extra_backend_tag = extra_config["extra_backend_tag"]
187
+ logger.info(f"Using extra_backend_tag: {self.extra_backend_tag}")
188
+
189
+ # Check server status
190
+ if self.config.check_server:
191
+ self.check_server()
192
+
154
193
  ret_code = self.store.setup(
155
194
  self.config.local_hostname,
156
195
  self.config.metadata_server,
@@ -181,6 +220,39 @@ class MooncakeStore(HiCacheStorage):
181
220
  logger.error("An error occurred while loading the configuration: %s", exc)
182
221
  raise
183
222
 
223
+ def check_server(self):
224
+ master_server_ip = self.config.master_server_address.split(":")[0]
225
+ segments_url = f"http://{master_server_ip}:{self.config.master_metrics_port}/get_all_segments"
226
+ start_time = time.perf_counter()
227
+
228
+ check_result = False
229
+ while time.perf_counter() - start_time < SETUP_TIMEOUT:
230
+ try:
231
+ check_segments_resp = requests.get(segments_url, timeout=3)
232
+ except Exception:
233
+ logger.info(
234
+ "waiting mooncake store server started, cost_time: %.2f seconds.",
235
+ time.perf_counter() - start_time,
236
+ )
237
+ time.sleep(3)
238
+ continue
239
+
240
+ if check_segments_resp.text == "":
241
+ logger.info(
242
+ "waiting mooncake store server started, cost_time: %.2f seconds.",
243
+ time.perf_counter() - start_time,
244
+ )
245
+ time.sleep(3)
246
+ continue
247
+
248
+ logger.info("Mooncake store server started successfully.")
249
+ check_result = True
250
+ break
251
+
252
+ if not check_result:
253
+ logger.error("Launch mooncake store server timeout")
254
+ raise ValueError("Launch mooncake store server timeout")
255
+
184
256
  def warmup(self):
185
257
  warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
186
258
  warmup_value = bytes(4 * 1024) # 4 KB
@@ -257,6 +329,11 @@ class MooncakeStore(HiCacheStorage):
257
329
  host_indices: torch.Tensor,
258
330
  extra_info: Optional[HiCacheStorageExtraInfo] = None,
259
331
  ) -> List[bool]:
332
+ # Apply extra_backend_tag prefix if available
333
+ if self.extra_backend_tag is not None:
334
+ prefix = self.extra_backend_tag
335
+ keys = [f"{prefix}_{key}" for key in keys]
336
+
260
337
  key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
261
338
  get_results = self._get_batch_zero_copy_impl(
262
339
  key_strs, buffer_ptrs, buffer_sizes
@@ -269,6 +346,11 @@ class MooncakeStore(HiCacheStorage):
269
346
  host_indices: torch.Tensor,
270
347
  extra_info: Optional[HiCacheStorageExtraInfo] = None,
271
348
  ) -> List[bool]:
349
+ # Apply extra_backend_tag prefix if available
350
+ if self.extra_backend_tag is not None:
351
+ prefix = self.extra_backend_tag
352
+ keys = [f"{prefix}_{key}" for key in keys]
353
+
272
354
  key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
273
355
  exist_result = self._batch_exist(key_strs)
274
356
 
@@ -399,7 +481,9 @@ class MooncakeStore(HiCacheStorage):
399
481
  exist_result = self._batch_exist([key])
400
482
  return exist_result[0] == 1
401
483
 
402
- def batch_exists(self, keys) -> int:
484
+ def batch_exists(
485
+ self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None
486
+ ) -> int:
403
487
  if self.is_mla_backend:
404
488
  query_keys = [f"{key}_k" for key in keys]
405
489
  key_multiplier = 1