sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,480 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
11
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
12
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
13
+ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
14
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
15
+ from sglang.srt.server_args import get_global_server_args
16
+ from sglang.srt.utils import support_triton
17
+
18
+ if TYPE_CHECKING:
19
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
# Populate the req-to-token table for a batch of extend requests.
#
# Intended to be launched with one program per request: `pid` indexes every
# per-request input below. For request `pid` the kernel writes into row
# `req_pool_indices[pid]` of `req_to_token_ptr`:
#   * columns [0, pre_lens[pid]) copied from the prefix index tensor whose raw
#     address is stored at `prefix_tensors[pid]` (reinterpreted as an int64*);
#   * columns [pre_lens[pid], seq_lens[pid]) copied from `out_cache_loc`,
#     starting at element offset sum(extend_lens[0:pid]).
#
# NOTE(review): the number of new entries copied is seq_lens[pid] - pre_lens[pid],
# while the source offset in out_cache_loc is derived from extend_lens. This
# assumes extend_lens[i] == seq_lens[i] - pre_lens[i] for every request -- TODO
# confirm at call sites.
#
# `req_to_token_ptr_stride` is declared tl.constexpr, so it is a compile-time
# constant for the kernel.
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,  # per-request row index into req_to_token_ptr
    prefix_tensors,  # per-request prefix tensor addresses (integers)
    pre_lens,  # per-request prefix length
    seq_lens,  # per-request total sequence length
    extend_lens,  # per-request count of new slots; used only to offset into out_cache_loc
    out_cache_loc,  # new slot indices for all requests, assumed concatenated in request order
    req_to_token_ptr_stride: tl.constexpr,  # row stride (elements) of req_to_token_ptr
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)
    # prefix_tensors stores raw addresses; turn this request's entry back into
    # a typed pointer so the prefix indices can be loaded directly.
    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))

    # write prefix
    # Copy the cached prefix in BLOCK_SIZE chunks; the mask trims the tail chunk.
    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < pre_len
        value = tl.load(prefix_tensor + offset, mask=mask)
        tl.store(
            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
            value,
            mask=mask,
        )

    # NOTE: This can be slow for large bs
    # Start of this request's slice in out_cache_loc is the sum of all earlier
    # requests' extend_lens, computed by a serial O(pid) scan in every program.
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    # Copy the newly allocated slots into columns [pre_len, seq_len) of the row.
    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
73
+
74
+
75
def write_cache_indices(
    out_cache_loc: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    req_pool_indices_cpu: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens_tensor: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    extend_lens_tensor: torch.Tensor,
    extend_lens_cpu: torch.Tensor,
    prefix_tensors: list[torch.Tensor],
    req_to_token_pool: ReqToTokenPool,
):
    """Write prefix and newly allocated slot indices into the req-to-token table.

    For each request ``i``, row ``req_pool_indices[i]`` of the pool receives
    ``prefix_tensors[i]`` in columns ``[0, prefix_lens[i])`` and the next
    ``extend_lens[i]`` entries of ``out_cache_loc`` in columns
    ``[prefix_lens[i], seq_lens[i])``. ``out_cache_loc`` is consumed
    sequentially in request order.

    When the configured attention backend supports Triton, the device-side
    ``*_tensor`` arguments drive a single kernel launch (one program per
    request). Otherwise the ``*_cpu`` arguments drive a per-request loop of
    ``req_to_token_pool.write`` calls.

    Args:
        out_cache_loc: Newly allocated slot indices for all requests, concatenated.
        req_pool_indices_tensor / req_pool_indices_cpu: Row index per request.
        prefix_lens_tensor / prefix_lens_cpu: Prefix length per request.
        seq_lens_tensor / seq_lens_cpu: Total sequence length per request.
        extend_lens_tensor / extend_lens_cpu: Number of new slots per request.
        prefix_tensors: Cached prefix indices for each request.
        req_to_token_pool: Pool whose ``req_to_token`` table is written.
    """
    if support_triton(get_global_server_args().attention_backend):
        # The kernel reinterprets each entry as an int64 pointer, so pin the
        # dtype explicitly instead of relying on inference (an empty list
        # would otherwise produce a float32 tensor).
        prefix_pointers = torch.tensor(
            [t.data_ptr() for t in prefix_tensors],
            dtype=torch.int64,
            device=req_to_token_pool.device,
        )
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
        write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
            req_to_token_pool.req_to_token,
            req_pool_indices_tensor,
            prefix_pointers,
            prefix_lens_tensor,
            seq_lens_tensor,
            extend_lens_tensor,
            out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
        )
        return

    # Fallback path: convert the CPU metadata to Python ints once rather than
    # calling .item() on four tensors for every request.
    req_indices = req_pool_indices_cpu.tolist()
    prefix_lens = prefix_lens_cpu.tolist()
    seq_lens = seq_lens_cpu.tolist()
    extend_lens = extend_lens_cpu.tolist()

    pt = 0  # read cursor into out_cache_loc
    for i, req_idx in enumerate(req_indices):
        prefix_len = prefix_lens[i]
        seq_len = seq_lens[i]
        extend_len = extend_lens[i]

        req_to_token_pool.write(
            (req_idx, slice(0, prefix_len)),
            prefix_tensors[i],
        )
        req_to_token_pool.write(
            (req_idx, slice(prefix_len, seq_len)),
            out_cache_loc[pt : pt + extend_len],
        )
        pt += extend_len
121
+
122
+
123
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return each request's ``req_to_token`` entry at position ``prefix_len - 1``.

    Dispatches to ``get_last_loc_torch`` for the ``ascend`` and
    ``torch_native`` attention backends and to ``get_last_loc_triton`` for
    every other backend. Requests with no prefix get ``-1``.

    Args:
        req_to_token: The req-to-token table.
        req_pool_indices_tensor: Row index per request.
        prefix_lens_tensor: Prefix length per request.

    Returns:
        One value per request.
    """
    # Read the backend once instead of fetching the global server args twice.
    attention_backend = get_global_server_args().attention_backend
    if attention_backend in ("ascend", "torch_native"):
        impl = get_last_loc_torch
    else:
        impl = get_last_loc_triton

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
137
+
138
+
139
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """PyTorch implementation of ``get_last_loc``.

    For each ``i``, gathers ``req_to_token[req_pool_indices[i], prefix_lens[i] - 1]``.
    Entries where ``prefix_lens[i] <= 0`` are replaced with ``-1``.
    """
    has_prefix = prefix_lens_tensor > 0
    # For an empty prefix the column index is -1, which wraps to the last
    # column; that value is discarded by the ``where`` below.
    gathered = req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1]
    sentinel = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(has_prefix, gathered, sentinel)
149
+
150
+
151
# Triton kernel backing ``get_last_loc_triton``.
# Program ``p`` handles requests [p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE). Each lane:
# - reads req_to_token[req_pool_index, prefix_len - 1], addressed as
#   row * req_to_token_stride + column;
# - writes -1 to ``result`` instead when prefix_len <= 0.
# Lanes past ``num_tokens`` are masked out of both the loads and the store.
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    block_id = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE) + block_id * BLOCK_SIZE
    in_bounds = lanes < num_tokens

    # Out-of-range lanes read prefix_len 0, so they never touch req_to_token.
    prefix_len = tl.load(prefix_lens_tensor + lanes, mask=in_bounds, other=0)
    pool_row = tl.load(req_pool_indices_tensor + lanes, mask=in_bounds, other=0)

    has_prefix = prefix_len > 0
    flat_index = pool_row * req_to_token_stride + (prefix_len - 1)
    last_loc = tl.load(req_to_token + flat_index, mask=has_prefix, other=-1)

    tl.store(result + lanes, last_loc, mask=in_bounds)
173
+
174
+
175
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Triton implementation of ``get_last_loc``.

    Launches ``get_last_loc_kernel`` with one program per 256 requests. The
    returned tensor matches ``prefix_lens_tensor`` in shape and dtype.
    """
    block_size = 256
    n = prefix_lens_tensor.shape[0]
    out = torch.empty_like(prefix_lens_tensor)
    get_last_loc_kernel[(triton.cdiv(n, block_size),)](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        out,
        n,
        req_to_token.stride(0),
        block_size,
    )
    return out
195
+
196
+
197
def alloc_token_slots(
    tree_cache: BasePrefixCache,
    num_tokens: int,
    backup_state: bool = False,
):
    """Allocate ``num_tokens`` KV-cache slots from ``tree_cache``'s allocator.

    Calls ``evict_from_tree_cache`` first so the cache can free room if needed.

    Args:
        tree_cache: Prefix cache that owns ``token_to_kv_pool_allocator``.
        num_tokens: Number of slots to allocate.
        backup_state: If True, take ``allocator.backup_state()`` before allocating
            and return it alongside the slots.

    Returns:
        ``out_cache_loc``, or ``(out_cache_loc, state)`` when ``backup_state`` is True.

    Raises:
        RuntimeError: if ``allocator.alloc`` returns ``None`` (out of memory).
    """
    allocator = tree_cache.token_to_kv_pool_allocator
    evict_from_tree_cache(tree_cache, num_tokens)

    state = allocator.backup_state() if backup_state else None
    out_cache_loc = allocator.alloc(num_tokens)

    if out_cache_loc is None:
        error_msg = (
            f"Out of memory. Try to lower your batch size.\n"
            f"Try to allocate {num_tokens} tokens.\n"
            f"{available_and_evictable_str(tree_cache)}"
        )
        logger.error(error_msg)
        # tree_cache was already dereferenced above, so it cannot be None here.
        tree_cache.pretty_print()
        raise RuntimeError(error_msg)

    if backup_state:
        return out_cache_loc, state
    return out_cache_loc
223
+
224
+
225
def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
    """Ask ``tree_cache`` to evict enough tokens for ``num_tokens`` new slots.

    Behaviour by case:
    - ``tree_cache`` is None or a ``ChunkCache`` / ``SWAChunkCache``: no-op.
    - Allocator exposes ``full_available_size`` (hybrid full + SWA pools): if
      either pool has fewer than ``num_tokens`` free slots, call
      ``tree_cache.evict(full_shortfall, swa_shortfall)``. Each shortfall is
      clamped at 0.
    - Otherwise: call ``tree_cache.evict(num_tokens)`` when the allocator has
      fewer than ``num_tokens`` free slots.
    """
    if tree_cache is None or isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
        return

    allocator = tree_cache.token_to_kv_pool_allocator

    if not hasattr(allocator, "full_available_size"):
        # Standard single-pool allocator.
        if allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)
        return

    # Hybrid allocator: the full and SWA pools are tracked separately.
    full_free = allocator.full_available_size()
    swa_free = allocator.swa_available_size()
    if full_free >= num_tokens and swa_free >= num_tokens:
        return
    tree_cache.evict(
        max(0, num_tokens - full_free),
        max(0, num_tokens - swa_free),
    )
248
+
249
+
250
def alloc_paged_token_slots_extend(
    tree_cache: BasePrefixCache,
    prefix_lens: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    extend_num_tokens: int,
    backup_state: bool = False,
):
    """Allocate paged KV-cache slots for the extend part of a batch.

    Before allocating, evicts for an over-estimate: ``extend_num_tokens`` plus
    one full page per request. The length tensors, ``last_loc`` and
    ``extend_num_tokens`` are forwarded unchanged to ``allocator.alloc_extend``.

    Returns:
        ``out_cache_loc``, or ``(out_cache_loc, state)`` when ``backup_state`` is
        True. ``state`` is the allocator snapshot taken before allocation.

    Raises:
        RuntimeError: if ``alloc_extend`` returns ``None`` (out of memory).
    """
    allocator = tree_cache.token_to_kv_pool_allocator
    # Over estimate the number of tokens: assume each request needs a new page.
    eviction_target = extend_num_tokens + len(seq_lens_cpu) * allocator.page_size
    evict_from_tree_cache(tree_cache, eviction_target)

    state = allocator.backup_state() if backup_state else None

    out_cache_loc = allocator.alloc_extend(
        prefix_lens,
        prefix_lens_cpu,
        seq_lens,
        seq_lens_cpu,
        last_loc,
        extend_num_tokens,
    )
    if out_cache_loc is None:
        error_msg = (
            f"Prefill out of memory. Try to lower your batch size.\n"
            f"Try to allocate {extend_num_tokens} tokens.\n"
            f"{available_and_evictable_str(tree_cache)}"
        )
        logger.error(error_msg)
        # tree_cache was already dereferenced above, so it cannot be None here.
        tree_cache.pretty_print()
        raise RuntimeError(error_msg)

    if backup_state:
        return out_cache_loc, state
    return out_cache_loc
290
+
291
+
292
def alloc_req_slots(
    req_to_token_pool: ReqToTokenPool,
    num_reqs: int,
    reqs: list[Req] | None,
    tree_cache: BasePrefixCache | None,
) -> list[int]:
    """Allocate request slots from the pool.

    For a ``HybridReqToTokenPool``, the mamba state pool is checked first. If it
    has fewer than ``num_reqs`` free entries and ``tree_cache`` is a
    ``MambaRadixCache``, the shortfall is evicted. The pool's ``alloc`` is then
    called with ``reqs``. Plain pools receive only the count.

    Raises:
        RuntimeError: if the pool's ``alloc`` returns ``None``.
    """
    if not isinstance(req_to_token_pool, HybridReqToTokenPool):
        req_pool_indices = req_to_token_pool.alloc(num_reqs)
    else:
        free_mamba = req_to_token_pool.mamba_pool.available_size()
        # isinstance() is False for None, so a missing tree_cache is skipped.
        if free_mamba < num_reqs and isinstance(tree_cache, MambaRadixCache):
            tree_cache.evict_mamba(max(0, num_reqs - free_mamba))
        req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)

    if req_pool_indices is None:
        raise RuntimeError(
            "alloc_req_slots runs out of memory. "
            "Please set a smaller number for `--max-running-requests`. "
            f"{req_to_token_pool.available_size()=}, "
            f"{num_reqs=}, "
        )
    return req_pool_indices
317
+
318
+
319
def alloc_for_extend(
    batch: ScheduleBatch,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Allocate KV cache for an extend batch and write the slots to req_to_token_pool.

    Steps:
    1. For an ``SWAChunkCache``, free each request's out-of-window SWA tokens.
    2. Allocate one request slot per request (``alloc_req_slots``).
    3. Allocate ``batch.extend_num_tokens`` KV slots. When ``page_size > 1`` the
       paged allocator is used, with each prefix's last slot (or -1 for an
       empty prefix) as ``last_loc``.
    4. Write prefix + new slots into each request's row (``write_cache_indices``).

    Returns:
        out_cache_loc: allocated cache locations
        req_pool_indices_device: request pool indices as a device tensor
        req_pool_indices: request pool indices as a list
    """
    tree_cache = batch.tree_cache

    # Free out-of-window SWA tokens.
    if isinstance(tree_cache, SWAChunkCache):
        for req, pre_len in zip(batch.reqs, batch.prefix_lens):
            tree_cache.evict_swa(
                req, pre_len, batch.model_config.attention_chunk_size
            )

    device = batch.device
    prefix_tensors = [req.prefix_indices for req in batch.reqs]

    # Per-request length metadata: CPU copies plus non_blocking device copies.
    prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64)
    extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64)
    prefix_lens_device = prefix_lens_cpu.to(device, non_blocking=True)
    extend_lens_device = extend_lens_cpu.to(device, non_blocking=True)

    req_pool_indices = alloc_req_slots(
        batch.req_to_token_pool, len(batch.reqs), batch.reqs, tree_cache
    )
    req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
    req_pool_indices_device = req_pool_indices_cpu.to(device, non_blocking=True)

    # KV slots; both helpers raise RuntimeError when out of memory.
    if tree_cache.page_size == 1:
        out_cache_loc = alloc_token_slots(tree_cache, batch.extend_num_tokens)
    else:
        # req_to_token is only written further below, so the last prefix slot
        # is read from the prefix tensors themselves (-1 for an empty prefix).
        last_loc = torch.cat(
            [
                prefix[-1:] if len(prefix) > 0 else torch.tensor([-1], device=device)
                for prefix in prefix_tensors
            ]
        )
        out_cache_loc = alloc_paged_token_slots_extend(
            tree_cache=tree_cache,
            prefix_lens=prefix_lens_device,
            prefix_lens_cpu=prefix_lens_cpu,
            seq_lens=batch.seq_lens,
            seq_lens_cpu=batch.seq_lens_cpu,
            last_loc=last_loc,
            extend_num_tokens=batch.extend_num_tokens,
        )

    write_cache_indices(
        out_cache_loc,
        req_pool_indices_device,
        req_pool_indices_cpu,
        prefix_lens_device,
        prefix_lens_cpu,
        batch.seq_lens,
        batch.seq_lens_cpu,
        extend_lens_device,
        extend_lens_cpu,
        prefix_tensors,
        batch.req_to_token_pool,
    )

    return out_cache_loc, req_pool_indices_device, req_pool_indices
388
+
389
+
390
def alloc_paged_token_slots_decode(
    tree_cache: BasePrefixCache,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    token_per_req: int = 1,
) -> torch.Tensor:
    """Allocate paged KV cache for decode batch.

    Evicts for an over-estimate of one full page per request, then calls
    ``allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)``.
    ``token_per_req`` is only used to size the out-of-memory message.

    Raises:
        RuntimeError: if ``alloc_decode`` returns ``None``.
    """
    allocator = tree_cache.token_to_kv_pool_allocator
    # Over estimate the number of tokens: assume each request needs a new page.
    evict_from_tree_cache(tree_cache, len(seq_lens) * allocator.page_size)

    out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
    if out_cache_loc is not None:
        return out_cache_loc

    error_msg = (
        f"Decode out of memory. Try to lower your batch size.\n"
        f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    # tree_cache was already dereferenced above, so it cannot be None here.
    tree_cache.pretty_print()
    raise RuntimeError(error_msg)
417
+
418
+
419
def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
    """
    Allocate KV cache for a decode step and write the slots to req_to_token_pool.

    Allocation depends on page size:
    - ``page_size == 1``: allocates ``bs * token_per_req`` slots.
    - Otherwise: calls the paged decode allocator with ``seq_lens + token_per_req``.

    The slots, cast to int32, are written at column ``seq_lens`` of each
    request's row. For encoder-decoder models that column is offset by
    ``encoder_lens``.

    Returns:
        out_cache_loc: allocated cache locations
    """
    tree_cache = batch.tree_cache

    # Free out-of-window SWA tokens.
    if isinstance(tree_cache, SWAChunkCache):
        for req in batch.reqs:
            tree_cache.evict_swa(
                req, req.seqlen - 1, batch.model_config.attention_chunk_size
            )

    num_reqs = batch.seq_lens.shape[0]

    if tree_cache.page_size == 1:
        out_cache_loc = alloc_token_slots(tree_cache, num_reqs * token_per_req)
    else:
        # Slot currently holding each request's last token, passed as last_loc.
        last_loc = batch.req_to_token_pool.req_to_token[
            batch.req_pool_indices, batch.seq_lens - 1
        ]
        out_cache_loc = alloc_paged_token_slots_decode(
            tree_cache=tree_cache,
            seq_lens=batch.seq_lens + token_per_req,
            seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
            last_loc=last_loc,
            token_per_req=token_per_req,
        )

    locs = (
        batch.encoder_lens + batch.seq_lens
        if batch.model_config.is_encoder_decoder
        else batch.seq_lens.clone()
    )
    # NOTE(review): one column per request is written here, while the
    # page_size == 1 path allocates bs * token_per_req slots — confirm callers
    # never reach this with token_per_req > 1 in that path.
    batch.req_to_token_pool.write(
        (batch.req_pool_indices, locs), out_cache_loc.to(torch.int32)
    )
    return out_cache_loc
462
+
463
+
464
def available_and_evictable_str(tree_cache) -> str:
    """Describe how many KV-cache tokens could be obtained: free plus evictable.

    - ``SWATokenToKVPoolAllocator``: the full and SWA pools are reported
      separately, followed by the LRU-list evictable sizes.
    - Any other allocator: a single total is reported.
    """
    allocator = tree_cache.token_to_kv_pool_allocator
    if not isinstance(allocator, SWATokenToKVPoolAllocator):
        available_size = allocator.available_size()
        evictable_size = tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    # The local names below are rendered verbatim by the f-string "=" specifiers.
    full_available_size = allocator.full_available_size()
    swa_available_size = allocator.swa_available_size()
    full_evictable_size = tree_cache.full_evictable_size()
    swa_evictable_size = tree_cache.swa_evictable_size()
    full_line = f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
    swa_line = f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
    full_lru_line = f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n"
    swa_lru_line = f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n"
    return full_line + swa_line + full_lru_line + swa_lru_line
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, Tuple, Union
5
+
6
+ if TYPE_CHECKING:
7
+ from sglang.srt.mem_cache.radix_cache import TreeNode
8
+
9
+
10
class EvictionStrategy(ABC):
    """Interface for eviction policies that rank radix-tree nodes.

    Each policy maps a node to a sortable priority key. The subclasses here
    return ``last_access_time`` for LRU, which suggests lower keys are evicted
    first — TODO confirm against the radix cache that consumes this.
    """

    @abstractmethod
    def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
        """Return the sort key for ``node``."""
        ...
14
+
15
+
16
class LRUStrategy(EvictionStrategy):
    """Least-recently-used: rank nodes by ``last_access_time``."""

    def get_priority(self, node: "TreeNode") -> float:
        last_access = node.last_access_time
        return last_access
19
+
20
+
21
class LFUStrategy(EvictionStrategy):
    """Least-frequently-used: rank by ``hit_count``, ties broken by ``last_access_time``."""

    def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
        hits = node.hit_count
        last_access = node.last_access_time
        return hits, last_access
24
+
25
+
26
class FIFOStrategy(EvictionStrategy):
    """First-in-first-out: rank nodes by ``creation_time``."""

    def get_priority(self, node: "TreeNode") -> float:
        created_at = node.creation_time
        return created_at
29
+
30
+
31
class MRUStrategy(EvictionStrategy):
    """Most-recently-used: rank by negated ``last_access_time`` (the reverse of LRU)."""

    def get_priority(self, node: "TreeNode") -> float:
        last_access = node.last_access_time
        return -last_access
34
+
35
+
36
class FILOStrategy(EvictionStrategy):
    """First-in-last-out: rank by negated ``creation_time`` (the reverse of FIFO)."""

    def get_priority(self, node: "TreeNode") -> float:
        created_at = node.creation_time
        return -created_at
@@ -7,6 +7,8 @@ from typing import Any, List, Optional
7
7
 
8
8
  import torch
9
9
 
10
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
11
+
10
12
  logger = logging.getLogger(__name__)
11
13
 
12
14
 
@@ -32,15 +34,47 @@ class HiCacheStorageConfig:
32
34
  extra_config: Optional[dict] = None
33
35
 
34
36
 
37
@dataclass
class HiCacheStorageExtraInfo:
    """Optional per-call context for HiCacheStorage.

    Passed as ``extra_info`` to ``batch_get_v1``, ``batch_set_v1`` and
    ``batch_exists``. Interpretation of the fields is left to each storage
    backend.
    """

    # Keys of the prefix preceding the keys in the current call — TODO confirm
    # semantics with the backends that read it. Defaults to None. (The previous
    # default was the 1-tuple ``(None,)``, an accidental trailing comma.)
    prefix_keys: Optional[List[str]] = None
    # Free-form, backend-specific information.
    extra_info: Optional[dict] = None
41
+
42
+
35
43
  class HiCacheStorage(ABC):
36
44
  """
37
45
  HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
38
46
  It abstracts the underlying storage mechanism, allowing different implementations to be used.
39
47
  """
40
48
 
41
- # todo, potentially pass model and TP configs into storage backend
42
49
  # TODO: the page size of the storage backend does not have to match that of the host memory pool
43
50
 
51
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
    """Attach the host KV memory pool to this storage backend.

    The pool is stored as ``self.mem_pool_host``. It is presumably the pool that
    the ``*_v1`` methods address through ``host_indices`` — confirm per backend.
    """
    setattr(self, "mem_pool_host", mem_pool_host)
53
+
54
def batch_get_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """
    Retrieve values for multiple keys into the registered host memory pool.

    Args:
        keys: Storage keys to fetch.
        host_indices: Indices into the host KV pool (see ``register_mem_pool_host``)
            identifying where fetched data should be placed — TODO confirm layout
            per backend.
        extra_info: Optional backend-specific context.

    Returns:
        Per the annotation, one success flag per key. This base implementation
        is a stub that returns ``None``. Backends supporting the v1 API
        presumably override it.
    """
    pass
65
+
66
def batch_set_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """
    Store values for multiple keys from the registered host memory pool.

    Args:
        keys: Storage keys to write.
        host_indices: Indices into the host KV pool (see ``register_mem_pool_host``)
            identifying the data to store — TODO confirm layout per backend.
        extra_info: Optional backend-specific context.

    Returns:
        Per the annotation, one success flag per key. This base implementation
        is a stub that returns ``None``. Backends supporting the v1 API
        presumably override it.
    """
    pass
77
+
44
78
  @abstractmethod
45
79
  def get(
46
80
  self,
@@ -54,6 +88,7 @@ class HiCacheStorage(ABC):
54
88
  """
55
89
  pass
56
90
 
91
+ # TODO: Deprecate
57
92
  @abstractmethod
58
93
  def batch_get(
59
94
  self,
@@ -81,6 +116,7 @@ class HiCacheStorage(ABC):
81
116
  """
82
117
  pass
83
118
 
119
+ # TODO: Deprecate
84
120
  @abstractmethod
85
121
  def batch_set(
86
122
  self,
@@ -103,7 +139,10 @@ class HiCacheStorage(ABC):
103
139
  """
104
140
  pass
105
141
 
106
- def batch_exists(self, keys: List[str]) -> int:
142
+ # TODO: Use a finer-grained return type (e.g., List[bool])
143
+ def batch_exists(
144
+ self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
145
+ ) -> int:
107
146
  """
108
147
  Check if the keys exist in the storage.
109
148
  return the number of consecutive existing keys from the start.
@@ -114,6 +153,9 @@ class HiCacheStorage(ABC):
114
153
  return i
115
154
  return len(keys)
116
155
 
156
def clear(self) -> None:
    """Hook for clearing the storage backend.

    A no-op in the base class; presumably overridden by backends that support it.
    """
    return None
158
+
117
159
  def get_stats(self):
118
160
  return None
119
161