sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,480 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
11
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
12
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
13
+ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
14
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
15
+ from sglang.srt.server_args import get_global_server_args
16
+ from sglang.srt.utils import support_triton
17
+
18
+ if TYPE_CHECKING:
19
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
20
+
21
# Module-level logger; the allocation helpers in this module report out-of-memory errors through it.
logger = logging.getLogger(__name__)
22
+
23
+
24
# Scatter KV-cache slot indices into the req_to_token table, one program per request.
#
# For program `pid` (request i) with table row r = req_pool_indices[i]:
#   * row r, columns [0, pre_lens[i]) are copied from the buffer whose raw address is
#     stored in prefix_tensors[i] (read through an int64 pointer);
#   * row r, columns [pre_lens[i], seq_lens[i]) are copied from out_cache_loc starting
#     at offset sum(extend_lens[:i]).
# NOTE(review): the read offset is built from extend_lens while the element count is
# seq_lens - pre_lens; this assumes extend_lens[i] == seq_lens[i] - pre_lens[i] — confirm.
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    prefix_tensors,  # per-request raw data addresses, reinterpreted as int64 pointers
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,  # elements between consecutive table rows
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)
    # The address table holds integers; turn this request's entry back into a pointer.
    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))

    # write prefix
    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < pre_len
        value = tl.load(prefix_tensor + offset, mask=mask)
        tl.store(
            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
            value,
            mask=mask,
        )

    # NOTE: This can be slow for large bs
    # Exclusive prefix sum of extend_lens: where this request's slots begin in
    # out_cache_loc. Each program recomputes it serially (O(pid) loads).
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    # write newly allocated slots right after the prefix
    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
73
+
74
+
75
def write_cache_indices(
    out_cache_loc: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    req_pool_indices_cpu: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens_tensor: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    extend_lens_tensor: torch.Tensor,
    extend_lens_cpu: torch.Tensor,
    prefix_tensors: list[torch.Tensor],
    req_to_token_pool: ReqToTokenPool,
):
    """Record prefix and newly allocated KV-cache indices in ``req_to_token_pool``.

    For request ``i`` with pool row ``r = req_pool_indices[i]``:

    * columns ``[0, prefix_lens[i])`` of row ``r`` receive ``prefix_tensors[i]``;
    * columns ``[prefix_lens[i], seq_lens[i])`` receive the next ``extend_lens[i]``
      entries of ``out_cache_loc``, which is consumed in request order.

    When the configured attention backend supports Triton, the ``*_tensor``
    arguments feed ``write_req_to_token_pool_triton``. Otherwise the ``*_cpu``
    arguments drive a Python loop over ``req_to_token_pool.write``.
    NOTE(review): both copies presumably hold the same values — confirm with callers.
    """
    if support_triton(get_global_server_args().attention_backend):
        # The kernel reinterprets each entry as an int64 pointer. Pin the dtype so an
        # empty batch does not yield a float32 tensor (torch.tensor([]) defaults to
        # float32); for non-empty input this matches the inferred int64.
        prefix_pointers = torch.tensor(
            [t.data_ptr() for t in prefix_tensors],
            dtype=torch.int64,
            device=req_to_token_pool.device,
        )
        req_to_token = req_to_token_pool.req_to_token
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
        write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
            req_to_token,
            req_pool_indices_tensor,
            prefix_pointers,
            prefix_lens_tensor,
            seq_lens_tensor,
            extend_lens_tensor,
            out_cache_loc,
            # The kernel uses this value as the row stride; stride(0) equals shape[1]
            # for a contiguous table and stays correct for strided views
            # (matches get_last_loc_triton).
            req_to_token.stride(0),
        )
    else:
        # Convert each CPU tensor to Python ints once instead of calling .item()
        # for every element inside the loop.
        req_indices = req_pool_indices_cpu.tolist()
        prefix_lens = prefix_lens_cpu.tolist()
        seq_lens = seq_lens_cpu.tolist()
        extend_lens = extend_lens_cpu.tolist()

        pt = 0  # read cursor into out_cache_loc
        for i, req_idx in enumerate(req_indices):
            prefix_len = prefix_lens[i]
            seq_len = seq_lens[i]
            extend_len = extend_lens[i]

            req_to_token_pool.write(
                (req_idx, slice(0, prefix_len)),
                prefix_tensors[i],
            )
            req_to_token_pool.write(
                (req_idx, slice(prefix_len, seq_len)),
                out_cache_loc[pt : pt + extend_len],
            )
            pt += extend_len
121
+
122
+
123
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return, per request, the KV-cache index of its last prefix token.

    Entry ``i`` is ``req_to_token[req_pool_indices_tensor[i], prefix_lens_tensor[i] - 1]``,
    or ``-1`` when ``prefix_lens_tensor[i] == 0``. The pure-PyTorch implementation is
    used for the ``"ascend"`` and ``"torch_native"`` attention backends. Every other
    backend uses the Triton implementation.
    """
    # Read the configured backend once instead of calling get_global_server_args() per check.
    backend = get_global_server_args().attention_backend
    if backend in ("ascend", "torch_native"):
        impl = get_last_loc_torch
    else:
        impl = get_last_loc_triton

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
137
+
138
+
139
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch lookup of each request's last prefix-token KV index.

    Gathers ``req_to_token[req_pool_indices_tensor[i], prefix_lens_tensor[i] - 1]``
    for every ``i`` and substitutes ``-1`` wherever the prefix is empty.
    """
    last_cols = prefix_lens_tensor - 1
    # An empty prefix gives column -1 (wrapping to the final column). That gathered
    # value is discarded by the mask below.
    gathered = req_to_token[req_pool_indices_tensor, last_cols]
    sentinel = torch.full_like(prefix_lens_tensor, -1)
    has_prefix = prefix_lens_tensor > 0
    return torch.where(has_prefix, gathered, sentinel)
+ )
149
+
150
+
151
# Triton kernel behind get_last_loc_triton. Each program handles BLOCK_SIZE requests.
# For request j it stores result[j] = req_to_token[req_pool_indices[j] * stride + prefix_lens[j] - 1]
# when prefix_lens[j] > 0, and -1 otherwise.
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,  # number of requests (length of prefix_lens_tensor)
    req_to_token_stride,  # row stride of req_to_token, in elements
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    # Out-of-range lanes read prefix_len 0, so token_mask below disables them as well.
    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    # Empty prefixes skip the load and produce the -1 sentinel.
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)
173
+
174
+
175
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Triton lookup of each request's last prefix-token KV index (``-1`` if no prefix).

    Launches ``get_last_loc_kernel`` with ceil(n / 256) programs of 256 lanes each.
    Returns a new tensor shaped and typed like ``prefix_lens_tensor``.
    """
    block = 256
    n_reqs = prefix_lens_tensor.shape[0]
    out = torch.empty_like(prefix_lens_tensor)
    launch_grid = (triton.cdiv(n_reqs, block),)

    row_stride = req_to_token.stride(0)
    get_last_loc_kernel[launch_grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        out,
        n_reqs,
        row_stride,
        block,
    )
    return out
195
+
196
+
197
def alloc_token_slots(
    tree_cache: BasePrefixCache,
    num_tokens: int,
    backup_state: bool = False,
):
    """Allocate ``num_tokens`` KV-cache slots from ``tree_cache``'s allocator.

    ``evict_from_tree_cache`` is called first so that the tree cache can free space.

    Returns:
        The allocated slot indices. When ``backup_state`` is true, returns
        ``(indices, state)`` instead, where ``state`` is ``allocator.backup_state()``
        captured before allocating.

    Raises:
        RuntimeError: if the allocator returns ``None``.
    """
    pool_allocator = tree_cache.token_to_kv_pool_allocator
    evict_from_tree_cache(tree_cache, num_tokens)

    saved_state = pool_allocator.backup_state() if backup_state else None
    slots = pool_allocator.alloc(num_tokens)

    if slots is not None:
        if backup_state:
            return slots, saved_state
        return slots

    error_msg = (
        f"Out of memory. Try to lower your batch size.\n"
        f"Try to allocate {num_tokens} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
223
+
224
+
225
def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
    """Evict from ``tree_cache`` so its allocator can serve ``num_tokens`` more tokens.

    Does nothing for ``None`` or for chunk caches (``ChunkCache`` / ``SWAChunkCache``).

    An allocator exposing ``full_available_size`` is treated as hybrid (presumably
    full-attention plus sliding-window pools). If either pool is short, each pool's
    shortfall is evicted.

    For any other allocator, ``num_tokens`` tokens are evicted whenever its
    available size is below ``num_tokens``.
    """
    if tree_cache is None or isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
        return

    allocator = tree_cache.token_to_kv_pool_allocator

    if not hasattr(allocator, "full_available_size"):
        # Standard allocator
        if allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)
        return

    # Hybrid allocator
    full_free = allocator.full_available_size()
    swa_free = allocator.swa_available_size()
    if full_free >= num_tokens and swa_free >= num_tokens:
        return

    full_shortfall = max(0, num_tokens - full_free)
    swa_shortfall = max(0, num_tokens - swa_free)
    tree_cache.evict(full_shortfall, swa_shortfall)
248
+
249
+
250
def alloc_paged_token_slots_extend(
    tree_cache: BasePrefixCache,
    prefix_lens: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    extend_num_tokens: int,
    backup_state: bool = False,
):
    """Allocate paged KV-cache slots for an extend (prefill) batch.

    Eviction is sized pessimistically as ``extend_num_tokens`` plus one page per
    request. Then ``allocator.alloc_extend`` is called with the batch's
    prefix/sequence lengths and ``last_loc``.

    Returns:
        The allocated slot indices. When ``backup_state`` is true, returns
        ``(indices, state)`` instead, where ``state`` is ``allocator.backup_state()``
        captured before allocating.

    Raises:
        RuntimeError: if ``alloc_extend`` returns ``None``.
    """
    allocator = tree_cache.token_to_kv_pool_allocator
    # Over estimate the number of tokens: assume each request needs a new page.
    worst_case_tokens = extend_num_tokens + len(seq_lens_cpu) * allocator.page_size
    evict_from_tree_cache(tree_cache, worst_case_tokens)

    saved_state = allocator.backup_state() if backup_state else None

    slots = allocator.alloc_extend(
        prefix_lens,
        prefix_lens_cpu,
        seq_lens,
        seq_lens_cpu,
        last_loc,
        extend_num_tokens,
    )

    if slots is not None:
        if backup_state:
            return slots, saved_state
        return slots

    error_msg = (
        f"Prefill out of memory. Try to lower your batch size.\n"
        f"Try to allocate {extend_num_tokens} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
290
+
291
+
292
def alloc_req_slots(
    req_to_token_pool: ReqToTokenPool,
    num_reqs: int,
    reqs: list[Req] | None,
    tree_cache: BasePrefixCache | None,
) -> list[int]:
    """Reserve ``num_reqs`` request slots in ``req_to_token_pool``.

    For a ``HybridReqToTokenPool``, the Mamba state pool is checked first.
    If it is short and ``tree_cache`` is a ``MambaRadixCache``, the deficit
    is evicted via ``evict_mamba``. The hybrid pool's ``alloc`` also
    receives ``reqs``.

    Raises:
        RuntimeError: if the pool returns ``None`` for the allocation.
    """
    if not isinstance(req_to_token_pool, HybridReqToTokenPool):
        req_pool_indices = req_to_token_pool.alloc(num_reqs)
    else:
        mamba_free = req_to_token_pool.mamba_pool.available_size()
        # isinstance() is False for None, so no separate None check is needed.
        if mamba_free < num_reqs and isinstance(tree_cache, MambaRadixCache):
            tree_cache.evict_mamba(max(0, num_reqs - mamba_free))
        req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)

    if req_pool_indices is None:
        raise RuntimeError(
            "alloc_req_slots runs out of memory. "
            "Please set a smaller number for `--max-running-requests`. "
            f"{req_to_token_pool.available_size()=}, "
            f"{num_reqs=}, "
        )
    return req_pool_indices
317
+
318
+
319
def alloc_for_extend(
    batch: ScheduleBatch,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """Allocate request slots and KV cache for an extend batch.

    The resulting mapping is written into ``batch.req_to_token_pool``.

    Returns:
        out_cache_loc: allocated KV-cache locations.
        req_pool_indices_device: request pool indices as a tensor on
            ``batch.device``.
        req_pool_indices: request pool indices as a Python list.
    """
    tree_cache = batch.tree_cache

    # Release SWA tokens that fall outside the attention window.
    if isinstance(tree_cache, SWAChunkCache):
        for req, pre_len in zip(batch.reqs, batch.prefix_lens):
            tree_cache.evict_swa(
                req, pre_len, batch.model_config.attention_chunk_size
            )

    num_reqs = len(batch.reqs)
    prefix_tensors = [req.prefix_indices for req in batch.reqs]

    # Host-side length tensors plus their (asynchronously copied) device twins.
    prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64)
    extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64)
    prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True)
    extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)

    req_pool_indices = alloc_req_slots(
        batch.req_to_token_pool, num_reqs, batch.reqs, tree_cache
    )
    req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
    req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)

    # Both allocation helpers raise RuntimeError when they run out of memory.
    if tree_cache.page_size == 1:
        out_cache_loc = alloc_token_slots(tree_cache, batch.extend_num_tokens)
    else:
        # Last cached location per request; -1 when the request has no prefix.
        last_loc = torch.cat(
            [
                t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device)
                for t in prefix_tensors
            ]
        )
        out_cache_loc = alloc_paged_token_slots_extend(
            tree_cache=tree_cache,
            prefix_lens=prefix_lens_device,
            prefix_lens_cpu=prefix_lens_cpu,
            seq_lens=batch.seq_lens,
            seq_lens_cpu=batch.seq_lens_cpu,
            last_loc=last_loc,
            extend_num_tokens=batch.extend_num_tokens,
        )

    write_cache_indices(
        out_cache_loc,
        req_pool_indices_device,
        req_pool_indices_cpu,
        prefix_lens_device,
        prefix_lens_cpu,
        batch.seq_lens,
        batch.seq_lens_cpu,
        extend_lens_device,
        extend_lens_cpu,
        prefix_tensors,
        batch.req_to_token_pool,
    )

    return out_cache_loc, req_pool_indices_device, req_pool_indices
388
+
389
+
390
def alloc_paged_token_slots_decode(
    tree_cache: BasePrefixCache,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    token_per_req: int = 1,
) -> torch.Tensor:
    """Allocate paged KV-cache slots for a decode batch.

    ``token_per_req`` only affects the size reported in the out-of-memory
    message; the allocation itself is delegated to ``alloc_decode``.

    Raises:
        RuntimeError: if ``alloc_decode`` returns ``None`` (out of memory).
    """
    allocator = tree_cache.token_to_kv_pool_allocator
    # Over-estimate: assume each request needs a brand-new page.
    evict_from_tree_cache(tree_cache, len(seq_lens) * allocator.page_size)

    out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
    if out_cache_loc is not None:
        return out_cache_loc

    error_msg = (
        f"Decode out of memory. Try to lower your batch size.\n"
        f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
417
+
418
+
419
def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
    """Allocate KV-cache slots for the next decode step.

    The new locations are written into ``batch.req_to_token_pool``.

    Returns:
        out_cache_loc: the allocated KV-cache locations.
    """
    tree_cache = batch.tree_cache

    # Release SWA tokens that fall outside the attention window.
    if isinstance(tree_cache, SWAChunkCache):
        for req in batch.reqs:
            tree_cache.evict_swa(
                req, req.seqlen - 1, batch.model_config.attention_chunk_size
            )

    if tree_cache.page_size == 1:
        # Token-granular pool: token_per_req slots for every request.
        num_new_tokens = batch.seq_lens.shape[0] * token_per_req
        out_cache_loc = alloc_token_slots(tree_cache, num_new_tokens)
    else:
        # Location of each request's current last token, read from the pool.
        last_loc = batch.req_to_token_pool.req_to_token[
            batch.req_pool_indices, batch.seq_lens - 1
        ]
        out_cache_loc = alloc_paged_token_slots_decode(
            tree_cache=tree_cache,
            seq_lens=batch.seq_lens + token_per_req,
            seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
            last_loc=last_loc,
            token_per_req=token_per_req,
        )

    # Column in req_to_token where the new entries go; for encoder-decoder
    # models the position is offset by encoder_lens.
    locs = (
        batch.encoder_lens + batch.seq_lens
        if batch.model_config.is_encoder_decoder
        else batch.seq_lens.clone()
    )

    batch.req_to_token_pool.write(
        (batch.req_pool_indices, locs), out_cache_loc.to(torch.int32)
    )
    return out_cache_loc
462
+
463
+
464
def available_and_evictable_str(tree_cache) -> str:
    """Describe free plus evictable KV-cache capacity, for OOM error messages.

    SWA allocators report the full and SWA pools separately and include the
    LRU-list evictable sizes; other allocators report a single total.
    """
    allocator = tree_cache.token_to_kv_pool_allocator

    if not isinstance(allocator, SWATokenToKVPoolAllocator):
        # NOTE: these local names appear verbatim in the output via f"{x=}".
        available_size = allocator.available_size()
        evictable_size = tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    # NOTE: these local names appear verbatim in the output via f"{x=}".
    full_available_size = allocator.full_available_size()
    swa_available_size = allocator.swa_available_size()
    full_evictable_size = tree_cache.full_evictable_size()
    swa_evictable_size = tree_cache.swa_evictable_size()
    return (
        f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
        f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
        f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n"
        f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n"
    )
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, List, Tuple, Union
4
+ from typing import TYPE_CHECKING, Tuple, Union
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from sglang.srt.mem_cache.radix_cache import TreeNode
@@ -21,3 +21,18 @@ class LRUStrategy(EvictionStrategy):
21
21
  class LFUStrategy(EvictionStrategy):
22
22
  def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
23
23
  return (node.hit_count, node.last_access_time)
24
+
25
+
26
class FIFOStrategy(EvictionStrategy):
    """First-in-first-out eviction ordering: the key is the node's creation time."""

    def get_priority(self, node: "TreeNode") -> float:
        created_at = node.creation_time
        return created_at
29
+
30
+
31
class MRUStrategy(EvictionStrategy):
    """Most-recently-used eviction ordering.

    The key is the negated last access time, so more recently accessed
    nodes get smaller priority values.
    """

    def get_priority(self, node: "TreeNode") -> float:
        last_access = node.last_access_time
        return -last_access
34
+
35
+
36
class FILOStrategy(EvictionStrategy):
    """First-in-last-out eviction ordering.

    The key is the negated creation time, so newer nodes get smaller
    priority values.
    """

    def get_priority(self, node: "TreeNode") -> float:
        created_at = node.creation_time
        return -created_at
@@ -36,6 +36,7 @@ class HiCacheStorageConfig:
36
36
 
37
37
  @dataclass
38
38
  class HiCacheStorageExtraInfo:
39
+ prefix_keys: Optional[List[str]] = (None,)
39
40
  extra_info: Optional[dict] = None
40
41
 
41
42
 
@@ -139,7 +140,9 @@ class HiCacheStorage(ABC):
139
140
  pass
140
141
 
141
142
  # TODO: Use a finer-grained return type (e.g., List[bool])
142
- def batch_exists(self, keys: List[str]) -> int:
143
+ def batch_exists(
144
+ self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
145
+ ) -> int:
143
146
  """
144
147
  Check if the keys exist in the storage.
145
148
  return the number of consecutive existing keys from the start.
@@ -84,12 +84,14 @@ class HiRadixCache(RadixCache):
84
84
  prefetch_threshold,
85
85
  prefetch_timeout_base,
86
86
  prefetch_timeout_per_ki_token,
87
+ hicache_storage_pass_prefix_keys,
87
88
  ) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
88
89
  self.prefetch_threshold = prefetch_threshold
89
90
  self.prefetch_timeout_base = prefetch_timeout_base
90
91
  self.prefetch_timeout_per_page = (
91
92
  page_size / 1024 * prefetch_timeout_per_ki_token
92
93
  )
94
+ self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys
93
95
  # TODO: support more timeout check functions
94
96
  self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
95
97
  self.prefetch_stop_policy = hicache_storage_prefetch_policy
@@ -149,7 +151,7 @@ class HiRadixCache(RadixCache):
149
151
  storage_backend_extra_config: JSON string containing extra configuration
150
152
 
151
153
  Returns:
152
- tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
154
+ tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys)
153
155
  """
154
156
  # Parse extra config JSON if provided
155
157
  extra_config = {}
@@ -165,6 +167,9 @@ class HiRadixCache(RadixCache):
165
167
  prefetch_timeout_per_ki_token = extra_config.pop(
166
168
  "prefetch_timeout_per_ki_token", 0.25
167
169
  ) # seconds per 1024 tokens
170
+ hicache_storage_pass_prefix_keys = extra_config.pop(
171
+ "hicache_storage_pass_prefix_keys", False
172
+ )
168
173
 
169
174
  if not isinstance(prefetch_threshold, int):
170
175
  raise ValueError(
@@ -184,6 +189,7 @@ class HiRadixCache(RadixCache):
184
189
  prefetch_threshold,
185
190
  float(prefetch_timeout_base),
186
191
  float(prefetch_timeout_per_ki_token),
192
+ hicache_storage_pass_prefix_keys,
187
193
  )
188
194
 
189
195
  def reset(self):
@@ -245,8 +251,14 @@ class HiRadixCache(RadixCache):
245
251
  return len(host_indices)
246
252
 
247
253
  def write_backup_storage(self, node: TreeNode):
254
+ prefix_keys = (
255
+ node.get_prefix_hash_values(node.parent)
256
+ if self.hicache_storage_pass_prefix_keys
257
+ else None
258
+ )
259
+
248
260
  operation_id = self.cache_controller.write_storage(
249
- node.host_value, node.key, node.hash_value
261
+ node.host_value, node.key, node.hash_value, prefix_keys
250
262
  )
251
263
  self.ongoing_backup[operation_id] = node
252
264
  node.protect_host()
@@ -700,6 +712,7 @@ class HiRadixCache(RadixCache):
700
712
  last_host_node: TreeNode,
701
713
  new_input_tokens: List[int],
702
714
  last_hash: Optional[str] = None,
715
+ prefix_keys: Optional[List[str]] = None,
703
716
  ):
704
717
  # align the number of fetching tokens to the page size
705
718
  prefetch_length = len(new_input_tokens) - (
@@ -723,7 +736,7 @@ class HiRadixCache(RadixCache):
723
736
  # no sufficient host memory for prefetch
724
737
  return
725
738
  operation = self.cache_controller.prefetch(
726
- req_id, host_indices, new_input_tokens, last_hash
739
+ req_id, host_indices, new_input_tokens, last_hash, prefix_keys
727
740
  )
728
741
  self.ongoing_prefetch[req_id] = (
729
742
  last_host_node,