sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
13
13
  TreeNodeCpp,
14
14
  )
15
15
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
16
+ from sglang.srt.mem_cache.radix_cache import RadixKey
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from sglang.srt.managers.schedule_batch import Req
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
93
94
  raise NotImplementedError("Host cache is not supported yet")
94
95
  self.tree.reset()
95
96
 
96
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
97
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
97
98
  device_indices_vec, host_indices_length, node_gpu, node_cpu = (
98
- self.tree.match_prefix(key)
99
+ self.tree.match_prefix(key.token_ids)
99
100
  )
100
101
  return MatchResult(
101
102
  device_indices=self._merge_tensor(device_indices_vec),
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
104
105
  host_hit_length=host_indices_length,
105
106
  )
106
107
 
107
- def _insert(self, key: List[int], value: torch.Tensor) -> int:
108
+ def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
108
109
  """
109
110
  Insert a key-value pair into the radix tree.
110
111
  Args:
111
- key (List[int]): The key to insert, represented as a list of integers.
112
+ key (RadixKey): The key to insert, represented as a RadixKey.
112
113
  value (torch.Tensor): The value to associate with the key.
113
114
  Returns:
114
115
  int: Number of device indices that were already present in the tree before the insertion.
115
116
  """
116
- ongoing_write, length = self.tree.writing_through(key, value)
117
+ ongoing_write, length = self.tree.writing_through(key.token_ids, value)
117
118
  if self.cache_controller is None:
118
119
  assert len(ongoing_write) == 0, "Implementation error"
119
120
  return length
@@ -150,32 +151,37 @@ class RadixCacheCpp(BasePrefixCache):
150
151
  def total_size(self):
151
152
  return self.tree.total_size()
152
153
 
153
- def cache_finished_req(self, req: Req):
154
+ def cache_finished_req(self, req: Req, is_insert: bool = True):
154
155
  """Cache request when it finishes."""
155
156
  assert req.req_pool_idx is not None
156
- token_ids = (req.origin_input_ids + req.output_ids)[:-1]
157
+ all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
158
+ token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
157
159
  overall_len = len(token_ids) # prefill + decode
158
160
  kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
159
161
 
160
162
  # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
161
163
  # it will automatically align them, but length of them should be equal
162
164
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
163
- new_prefix_len = self._insert(token_ids, kv_indices)
165
+ page_aligned_overall_len = overall_len // self.page_size * self.page_size
164
166
 
165
- # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
166
- assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
167
-
168
- # KVCache between old & new is newly generated, but already exists in the pool
169
- # we need to free this newly generated kv indices
170
- if old_prefix_len < new_prefix_len:
171
- self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
167
+ if is_insert:
168
+ new_prefix_len = self._insert(
169
+ RadixKey(token_ids, req.extra_key), kv_indices
170
+ )
171
+ # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
172
+ assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
173
+ # Free duplicates that were already in the pool
174
+ if old_prefix_len < new_prefix_len:
175
+ self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
176
+ else:
177
+ self.token_to_kv_pool.free(
178
+ kv_indices[old_prefix_len:page_aligned_overall_len]
179
+ )
172
180
 
173
181
  # need to free the unaligned part, since it cannot be inserted into the radix tree
174
- if self.page_size != 1 and ( # unaligned tail only exists when page_size > 1
175
- (unaligned_len := overall_len % self.page_size) > 0
176
- ):
182
+ if page_aligned_overall_len < overall_len:
177
183
  # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
178
- self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
184
+ self.token_to_kv_pool.free(kv_indices[page_aligned_overall_len:])
179
185
 
180
186
  # Remove req slot release the cache lock
181
187
  self.dec_lock_ref(req.last_node)
@@ -191,14 +197,16 @@ class RadixCacheCpp(BasePrefixCache):
191
197
  # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
192
198
  # it will automatically align them, but length of them should be equal
193
199
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
194
- new_prefix_len = self._insert(token_ids, kv_indices)
200
+ new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
195
201
 
196
202
  # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
197
203
  assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
198
204
 
199
205
  # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
200
206
  # The prefix indices need to updated to reuse the kv indices in the pool
201
- new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
207
+ new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
208
+ RadixKey(token_ids, req.extra_key).token_ids
209
+ )
202
210
  new_indices = self._merge_tensor(new_indices_vec)
203
211
  assert new_prefix_len <= len(new_indices)
204
212
 
@@ -0,0 +1,10 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to SGLang project
3
+
4
+ """Storage backend module for SGLang HiCache."""
5
+
6
+ from .backend_factory import StorageBackendFactory
7
+
8
# Public API of the HiCache storage package.
__all__ = ["StorageBackendFactory"]
@@ -0,0 +1,157 @@
1
+ import logging
2
+ from typing import Any, List, Optional
3
+
4
+ import torch
5
+ from aibrix_kvcache import (
6
+ BaseKVCacheManager,
7
+ BlockHashes,
8
+ KVCacheBlockLayout,
9
+ KVCacheBlockSpec,
10
+ KVCacheConfig,
11
+ KVCacheTensorSpec,
12
+ ModelSpec,
13
+ )
14
+ from aibrix_kvcache.common.absl_logging import log_every_n_seconds
15
+
16
+ from sglang.srt.mem_cache.hicache_storage import (
17
+ HiCacheStorage,
18
+ HiCacheStorageConfig,
19
+ HiCacheStorageExtraInfo,
20
+ )
21
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
22
+
23
logger = logging.getLogger(__name__)


class AibrixKVCacheStorage(HiCacheStorage):
    """HiCache storage backend that keeps KV pages in an AIBrix KV cache.

    Each HiCache page of ``page_size`` tokens maps to one AIBrix block whose
    spec (heads, layers, head_dim, dtype) is derived from the device KV pool
    behind ``mem_pool``. MLA models are not supported.
    """

    # Minimum interval, in seconds, between two AIBrix metrics reports.
    _METRICS_REPORT_INTERVAL_S = 1.0

    def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
        """Build the AIBrix block spec and KV cache manager.

        Args:
            storage_config: Supplies ``is_mla_model`` and ``tp_rank``. When
                ``None``, a non-MLA model on rank 0 is assumed.
            mem_pool: Host KV pool. Its ``device_pool`` supplies dtype, head
                and layer geometry; its ``page_size`` sets the block size.

        Raises:
            NotImplementedError: If the model uses MLA.
        """
        if storage_config is not None:
            self.is_mla_backend = storage_config.is_mla_model
            self.local_rank = storage_config.tp_rank
        else:
            self.is_mla_backend = False
            self.local_rank = 0
        kv_cache = mem_pool.device_pool
        self.page_size = mem_pool.page_size
        self.kv_cache_dtype = kv_cache.dtype
        self.layer_num = kv_cache.layer_num
        # Rank r owns global KV heads [r * head_num, (r + 1) * head_num).
        self.kv_head_ids = [
            self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
        ]
        if self.is_mla_backend:
            raise NotImplementedError(
                "MLA is not supported by AibrixKVCacheStorage yet."
            )

        # Only the layers held by this pipeline-parallel stage.
        self.layer_ids = range(kv_cache.start_layer, kv_cache.end_layer)
        self.block_spec = KVCacheBlockSpec(
            block_ntokens=self.page_size,
            block_dtype=self.kv_cache_dtype,
            block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
            tensor_spec=KVCacheTensorSpec(
                heads=self.kv_head_ids,
                layers=self.layer_ids,
                head_size=kv_cache.head_dim,
            ),
        )
        logger.info(self.block_spec)
        # NOTE(review): the meaning of 102400 for ModelSpec is not visible here
        # (presumably a maximum model length) - confirm against aibrix_kvcache.
        config = KVCacheConfig(
            block_spec=self.block_spec, model_spec=ModelSpec(102400)
        )
        self.kv_cache_manager = BaseKVCacheManager(config)
        # time.monotonic() of the last metrics report; None means never reported.
        self._last_metrics_report: Optional[float] = None

    def _aibrix_kvcache_metrics_report(self):
        """Return the current AIBrix metrics summary and reset the counters."""
        summary = self.kv_cache_manager.metrics.summary()
        self.kv_cache_manager.metrics.reset()
        return summary

    def _maybe_report_metrics(self) -> None:
        """Log and reset AIBrix metrics at most once per report interval.

        The summary is only computed when the interval has elapsed. Computing
        it eagerly on every call would reset the counters on each request.
        """
        import time

        now = time.monotonic()
        last = getattr(self, "_last_metrics_report", None)
        if last is not None and now - last < self._METRICS_REPORT_INTERVAL_S:
            return
        self._last_metrics_report = now
        summary = self._aibrix_kvcache_metrics_report()
        if summary:
            logger.info(summary)

    def batch_get(
        self,
        keys: List[str],
        target_locations: List[torch.Tensor],
        target_sizes: Optional[Any] = None,
    ) -> List[torch.Tensor | None]:
        """Fetch the pages for ``keys`` into ``target_locations``.

        Args:
            keys: Page hash keys, in prefix order.
            target_locations: One destination tensor per key. Each must have
                the same ``nbytes`` as the corresponding AIBrix block.
            target_sizes: Unused.

        Returns:
            A list aligned with ``keys``: the filled destination tensor for
            every fetched page, and ``None`` for pages that were not fetched.
            A partial (prefix) hit therefore yields trailing ``None`` entries.
        """
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.acquire(None, block_hash)
        self._maybe_report_metrics()
        if not status.is_ok():
            return [None] * len(keys)

        _num_fetched_tokens, handle = status.value
        try:
            kv_blocks = handle.to_tensors()
            assert len(kv_blocks) <= len(
                target_locations
            ), f"{len(kv_blocks)}, {len(target_locations)}"
            for target, block in zip(target_locations, kv_blocks):
                assert target.nbytes == block.nbytes, f"{target.nbytes}, {block.nbytes}"
                target.copy_(block.reshape(target.shape))
        finally:
            # Always hand the acquired blocks back, even if a copy fails.
            handle.release()

        num_fetched = len(kv_blocks)
        return list(target_locations[:num_fetched]) + [None] * (
            len(keys) - num_fetched
        )

    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Single-key wrapper around :meth:`batch_get`."""
        return self.batch_get([key], [target_location], [target_size])[0]

    def batch_set(
        self,
        keys: List[str],
        values: Optional[Any] = None,
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """Store one page tensor per key.

        Args:
            keys: Page hash keys, in prefix order.
            values: One source tensor per key. Each must have the same
                ``nbytes`` as an AIBrix block.
            target_locations: Unused.
            target_sizes: Unused.

        Returns:
            True only if AIBrix reports that every token of every page was
            stored.
        """
        if values is None:
            logger.warning("aibrix_kvcache set called without values")
            return False
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.allocate_for(None, block_hash)
        if not status.is_ok():
            logger.warning(
                f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
            )
            return False
        handle = status.value
        try:
            tensors = handle.to_tensors()
            if len(tensors) != len(values):
                logger.warning("aibrix_kvcache set allocate not enough")
                handle.release()
                return False
            for dst, src in zip(tensors, values):
                assert dst.nbytes == src.nbytes, f"{dst.nbytes}, {src.nbytes}"
                # Copy into ``dst`` itself. ``dst.reshape(...)`` may return a
                # copy, which would silently drop the write.
                dst.copy_(src.reshape(dst.shape))
        except BaseException:
            # The handle was never handed to ``put``; release it before propagating.
            handle.release()
            raise
        status = self.kv_cache_manager.put(None, block_hash, handle)
        if not status.is_ok():
            logger.info(
                f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
            )
            return False
        completed = status.value
        return completed == len(keys) * self.page_size

    def set(
        self,
        key: str,
        value: Optional[Any] = None,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """Single-key wrapper around :meth:`batch_set`."""
        return self.batch_set([key], [value], [target_location], [target_size])

    def batch_exists(
        self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
    ) -> int:
        """Return how many pages AIBrix reports as present (0 on error)."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.exists(None, block_hash)
        if status.is_ok():
            # AIBrix reports a token count; convert it to whole pages.
            return status.value // self.page_size
        return 0

    def exists(self, key: str) -> bool | dict:
        """Return True if the page for ``key`` is present."""
        return self.batch_exists([key]) > 0
@@ -0,0 +1,97 @@
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ import torch.distributed
6
+ from aibrix_kvcache.common.absl_logging import log_every_n_seconds
7
+ from aibrix_kvcache_storage import AibrixKVCacheStorage
8
+
9
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
10
+ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
11
+ from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
12
+
13
# Root logger configuration for this standalone test script.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)


def setup():
    """Set single-process distributed env vars (RANK, WORLD_SIZE, MASTER_ADDR/PORT)."""
    os.environ.update(
        {
            "RANK": "0",
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": "63886",
        }
    )
25
+
26
+
27
class AIBrixKVCacheStorageTest:
    """Round-trip smoke test for AibrixKVCacheStorage on a CPU MHA KV pool."""

    def test_with_page_size(self):
        """Exercise batch_set / batch_get / batch_exists for page sizes 1 and 2."""
        config = HiCacheStorageConfig(
            tp_rank=0,
            tp_size=1,
            is_mla_model=False,
            is_page_first_layout=True,
            model_name="test",
        )
        for page_size in range(1, 3):
            logger.info(f"page_size: {page_size}")
            batch_size = 2
            head_num = 1
            layer_num = 64
            head_dim = 128
            kv_cache = MHATokenToKVPool(
                1024,
                page_size,
                torch.float16,
                head_num,
                head_dim,
                layer_num,
                "cpu",
                False,
                0,
                layer_num,
            )
            mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
            query_length = batch_size * 2
            partial = batch_size
            self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
            # One tensor per page; the leading 2 is presumably K and V - confirm.
            target_shape = (2, layer_num, page_size, head_num, head_dim)
            rand_tensor = [
                torch.rand(target_shape, dtype=torch.float16)
                for _ in range(query_length)
            ]
            keys = ["hash" + str(i) for i in range(query_length)]
            # The second half of the keys, used to check lookups that don't
            # start at the first stored page.
            partial_keys = keys[batch_size:query_length]

            assert self.aibrix_kvcache.batch_exists(keys) == 0
            assert self.aibrix_kvcache.batch_set(keys, rand_tensor)

            get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(query_length)
            ]
            results = self.aibrix_kvcache.batch_get(keys, get_tensor)
            assert all(r is not None for r in results)
            for i in range(query_length):
                assert torch.equal(get_tensor[i], rand_tensor[i].flatten())

            assert self.aibrix_kvcache.batch_exists(keys) == query_length
            assert self.aibrix_kvcache.batch_exists(partial_keys) == partial

            partial_get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(partial)
            ]
            self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
            for i in range(partial):
                assert torch.equal(
                    partial_get_tensor[i], rand_tensor[i + partial].flatten()
                )
            log_every_n_seconds(
                logger,
                logging.INFO,
                self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
                1,
            )


if __name__ == "__main__":
    setup()
    test = AIBrixKVCacheStorageTest()
    test.test_with_page_size()
@@ -0,0 +1,223 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to SGLang project
3
+
4
+ import importlib
5
+ import logging
6
+ from typing import TYPE_CHECKING, Any, Dict
7
+
8
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
9
+
10
logger = logging.getLogger(__name__)


class StorageBackendFactory:
    """Factory for creating storage backend instances with support for dynamic loading.

    Backends are registered by name via ``register_backend``; their modules are
    imported lazily the first time the backend is created. The special name
    ``"dynamic"`` loads an arbitrary backend described by
    ``storage_config.extra_config``.
    """

    # Maps backend name -> {"loader", "module_path", "class_name"}. This is a
    # class attribute, so it is shared by every user of the factory.
    _registry: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _load_backend_class(
        module_path: str, class_name: str, backend_name: str
    ) -> type[HiCacheStorage]:
        """Import ``module_path`` and return its ``class_name`` attribute.

        Raises:
            ImportError: If the module cannot be imported.
            AttributeError: If the class is missing from the module.
            TypeError: If the attribute is not a HiCacheStorage subclass.
        """
        try:
            module = importlib.import_module(module_path)
            backend_class = getattr(module, class_name)
        except ImportError as e:
            raise ImportError(
                f"Failed to import backend '{backend_name}' from '{module_path}': {e}"
            ) from e
        except AttributeError as e:
            raise AttributeError(
                f"Class '{class_name}' not found in module '{module_path}': {e}"
            ) from e
        # Check isinstance first: issubclass() itself raises a less helpful
        # TypeError when the attribute is not a class.
        if not (
            isinstance(backend_class, type) and issubclass(backend_class, HiCacheStorage)
        ):
            raise TypeError(
                f"Backend class {class_name} must inherit from HiCacheStorage"
            )
        return backend_class

    @classmethod
    def register_backend(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a storage backend with lazy loading.

        Args:
            name: Backend identifier
            module_path: Python module path containing the backend class
            class_name: Name of the backend class
        """
        if name in cls._registry:
            logger.warning(f"Backend '{name}' is already registered, overwriting")

        def loader() -> type[HiCacheStorage]:
            """Lazy loader function to import the backend class."""
            return cls._load_backend_class(module_path, class_name, name)

        cls._registry[name] = {
            "loader": loader,
            "module_path": module_path,
            "class_name": class_name,
        }

    @classmethod
    def create_backend(
        cls,
        backend_name: str,
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
        **kwargs,
    ) -> HiCacheStorage:
        """Create a storage backend instance.

        Args:
            backend_name: Name of the backend to create
            storage_config: Storage configuration
            mem_pool_host: Memory pool host object
            **kwargs: Additional arguments passed to external backends

        Returns:
            Initialized storage backend instance

        Raises:
            ValueError: If the backend is not registered and cannot be
                dynamically loaded, or if ``"dynamic"`` is requested without
                ``storage_config.extra_config``.
            ImportError: If the backend module cannot be imported
            Exception: If backend initialization fails
        """
        # Registered backends take priority.
        registry_entry = cls._registry.get(backend_name)
        if registry_entry is not None:
            backend_class = registry_entry["loader"]()
            logger.info(
                f"Creating storage backend '{backend_name}' "
                f"({registry_entry['module_path']}.{registry_entry['class_name']})"
            )
            # NOTE(review): only the names handled in _create_builtin_backend
            # can be created here; other registered names raise ValueError.
            return cls._create_builtin_backend(
                backend_name, backend_class, storage_config, mem_pool_host
            )

        # Dynamically load the backend described by extra_config.
        if backend_name == "dynamic":
            backend_config = storage_config.extra_config
            if backend_config is None:
                raise ValueError(
                    "Storage backend 'dynamic' requires storage_config.extra_config "
                    "with 'backend_name', 'module_path' and 'class_name'."
                )
            return cls._create_dynamic_backend(
                backend_config, storage_config, mem_pool_host, **kwargs
            )

        available_backends = list(cls._registry.keys())
        raise ValueError(
            f"Unknown storage backend '{backend_name}'. "
            f"Registered backends: {available_backends}. "
        )

    @classmethod
    def _create_dynamic_backend(
        cls,
        backend_config: Dict[str, Any],
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
        **kwargs,
    ) -> HiCacheStorage:
        """Create a backend described by ``backend_config``.

        Raises:
            ValueError: If a required field is missing.
            Any import, lookup or constructor error is logged and re-raised.
        """
        required_fields = ["backend_name", "module_path", "class_name"]
        for field in required_fields:
            if field not in backend_config:
                raise ValueError(
                    f"Missing required field '{field}' in backend config for 'dynamic' backend"
                )

        backend_name = backend_config["backend_name"]
        module_path = backend_config["module_path"]
        class_name = backend_config["class_name"]

        try:
            backend_class = cls._load_backend_class(
                module_path, class_name, backend_name
            )
            logger.info(
                f"Creating dynamic storage backend '{backend_name}' "
                f"({module_path}.{class_name})"
            )
            # Dynamic backends are constructed as ``cls(storage_config, kwargs)``.
            # The extra keyword arguments arrive as one dict positional
            # argument, and mem_pool_host is not forwarded.
            return backend_class(storage_config, kwargs)
        except Exception as e:
            logger.error(
                f"Failed to create dynamic storage backend '{backend_name}': {e}"
            )
            raise

    @staticmethod
    def _hf3fs_bytes_per_page(mem_pool_host: Any) -> int:
        """Return the per-page byte size the hf3fs backend needs for this host pool layout.

        Raises:
            ValueError: If the host pool layout is not one hf3fs supports.
        """
        layout = mem_pool_host.layout
        if layout in ("page_first", "page_first_direct"):
            per_token = mem_pool_host.get_ksize_per_token()
        elif layout == "layer_first":
            per_token = mem_pool_host.get_size_per_token()
        else:
            raise ValueError(
                f"Unsupported memory pool layout for hf3fs backend: {layout}"
            )
        return per_token * mem_pool_host.page_size

    @classmethod
    def _create_builtin_backend(
        cls,
        backend_name: str,
        backend_class: type[HiCacheStorage],
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
    ) -> HiCacheStorage:
        """Create a built-in backend using that backend's constructor convention.

        Raises:
            ValueError: If ``backend_name`` is not a known built-in backend, or
                the hf3fs host pool layout is unsupported.
        """
        if backend_name in ("file", "nixl", "mooncake"):
            return backend_class(storage_config)
        if backend_name in ("aibrix", "eic"):
            return backend_class(storage_config, mem_pool_host)
        if backend_name == "hf3fs":
            bytes_per_page = cls._hf3fs_bytes_per_page(mem_pool_host)
            dtype = mem_pool_host.dtype
            return backend_class.from_env_config(bytes_per_page, dtype, storage_config)
        raise ValueError(f"Unknown built-in backend: {backend_name}")
188
+
189
+
190
# Built-in HiCache storage backends. Each module is imported lazily, only when
# its backend is first created, so optional dependencies are not required
# at import time.
StorageBackendFactory.register_backend(
    name="file",
    module_path="sglang.srt.mem_cache.hicache_storage",
    class_name="HiCacheFile",
)
StorageBackendFactory.register_backend(
    name="nixl",
    module_path="sglang.srt.mem_cache.storage.nixl.hicache_nixl",
    class_name="HiCacheNixl",
)
StorageBackendFactory.register_backend(
    name="mooncake",
    module_path="sglang.srt.mem_cache.storage.mooncake_store.mooncake_store",
    class_name="MooncakeStore",
)
StorageBackendFactory.register_backend(
    name="hf3fs",
    module_path="sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs",
    class_name="HiCacheHF3FS",
)
StorageBackendFactory.register_backend(
    name="aibrix",
    module_path="sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage",
    class_name="AibrixKVCacheStorage",
)
StorageBackendFactory.register_backend(
    name="eic",
    module_path="sglang.srt.mem_cache.storage.eic.eic_storage",
    class_name="EICStorage",
)