sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,115 @@
1
+ import argparse
2
+ import os
3
+
4
+ import eic
5
+ import torch
6
+ import yaml
7
+
8
+
9
def pase_args(argv=None):
    """Parse the command-line options of the EIC storage unit test.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is exactly
            what the previous zero-argument form did. Passing a list lets
            callers drive the parser programmatically.

    Returns:
        argparse.Namespace whose ``config`` attribute is the path of the
        EIC YAML config file.
    """
    parser = argparse.ArgumentParser(description="EIC Storage Unit Test")
    parser.add_argument(
        "--config",
        "-c",
        type=str,
        default="/sgl-workspace/config/remote-eic.yaml",
        help="EIC yaml config",
    )
    # parse_known_args tolerates unrecognised options instead of exiting;
    # the list of leftovers is deliberately discarded.
    args, _ = parser.parse_known_args(argv)
    return args
20
+
21
+
22
def init_eic_client():
    """Create and initialise an EIC client from the YAML config file.

    The config path comes from the command line (see ``pase_args``). The
    keys read from the YAML mapping are ``remote_url``, ``eic_instance_id``,
    ``eic_log_dir``, ``eic_log_level``, ``eic_trans_type`` and
    ``eic_flag_file``.

    Returns:
        The ``eic.Client`` instance after a successful ``init`` call.

    Raises:
        FileNotFoundError: The config file does not exist.
        ValueError: The config has no ``remote_url`` entry.
        RuntimeError: ``eic.Client.init`` returned a non-zero code.
    """
    args = pase_args()
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path, "r") as fin:
        # An empty YAML document loads as None; treat it as an empty mapping
        # so the missing-key check below reports the real problem.
        config = yaml.safe_load(fin) or {}

    remote_url = config.get("remote_url", None)
    if remote_url is None:
        # The original built an AssertionError here without raising it, so
        # execution fell through to a TypeError on the slice below.
        raise ValueError("remote_url is None")
    # Drops the first len("eic://") characters.
    # NOTE(review): assumes remote_url always starts with "eic://" -- confirm.
    endpoint = remote_url[len("eic://") :]
    eic_instance_id = config.get("eic_instance_id", None)
    eic_log_dir = config.get("eic_log_dir", None)
    eic_log_level = config.get("eic_log_level", 2)
    eic_trans_type = config.get("eic_trans_type", 3)
    eic_flag_file = config.get("eic_flag_file", None)

    eic_client = eic.Client()
    init_option = eic.InitOption()
    if eic_log_dir is not None:
        # os.path.exists(None) raises TypeError, so the directory is only
        # created (and configured) when the config actually provides one.
        # NOTE(review): presumably InitOption has a usable default log_dir
        # when none is set -- confirm against the eic bindings.
        os.makedirs(eic_log_dir, exist_ok=True)
        init_option.log_dir = eic_log_dir
    init_option.log_level = eic.LogLevel(eic_log_level)
    init_option.transport_type = eic.TransportType(eic_trans_type)
    init_option.flag_file = eic_flag_file
    ret = eic_client.init(eic_instance_id, endpoint, init_option)
    if ret != 0:
        raise RuntimeError(f"EIC Client init failed with error code: {ret}")
    return eic_client
52
+
53
+
54
def test_set(eic_client):
    """Store 16 CPU bfloat16 tensors of ones through a single ``mset`` call.

    Keys are ``test_key_0`` .. ``test_key_15`` and the set option carries
    ``ttl_second = 3``. Asserts that the returned status code equals
    ``eic.StatusCode.SUCCESS``.
    """
    item_count = 16
    key_names = [f"test_key_{idx}" for idx in range(item_count)]
    payloads = [
        torch.ones([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu")
        for _ in range(item_count)
    ]

    data_keys = eic.StringVector()
    data_vals = eic.IOBuffers()
    for name, payload in zip(key_names, payloads):
        data_keys.append(name)
        # Each buffer is described by its raw data pointer and byte length.
        byte_len = payload.numel() * payload.element_size()
        data_vals.append(payload.data_ptr(), byte_len, False)

    set_opt = eic.SetOption()
    set_opt.ttl_second = 3
    status_code, _outcome = eic_client.mset(data_keys, data_vals, set_opt)
    assert (
        status_code == eic.StatusCode.SUCCESS
    ), f"Set failed with status code: {status_code}"
73
+
74
+
75
def test_get(eic_client):
    """Fetch the 16 test keys through a single ``mget`` call.

    Zero-filled CPU bfloat16 tensors are registered as the buffers passed to
    ``mget``. Only the returned status code is checked; the tensor contents
    are not inspected.
    """
    item_count = 16
    key_names = [f"test_key_{idx}" for idx in range(item_count)]
    targets = [
        torch.zeros([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu")
        for _ in range(item_count)
    ]

    data_keys = eic.StringVector()
    data_vals = eic.IOBuffers()
    for name, target in zip(key_names, targets):
        data_keys.append(name)
        # Each buffer is described by its raw data pointer and byte length.
        byte_len = target.numel() * target.element_size()
        data_vals.append(target.data_ptr(), byte_len, False)

    get_opt = eic.GetOption()
    status_code, data_vals, _outcome = eic_client.mget(data_keys, get_opt, data_vals)
    assert (
        status_code == eic.StatusCode.SUCCESS
    ), f"Get failed with status code: {status_code}"
93
+
94
+
95
def test_exists(eic_client):
    """Query the 16 test keys with one ``mexist`` call.

    Only the call's status code is asserted; the per-key outcome returned
    alongside it is ignored.
    """
    data_keys = eic.StringVector()
    for idx in range(16):
        data_keys.append(f"test_key_{idx}")

    exists_opt = eic.ExistOption()
    status_code, _outcome = eic_client.mexist(data_keys, exists_opt)
    assert (
        status_code == eic.StatusCode.SUCCESS
    ), f"Exists failed with status code: {status_code}"
105
+
106
+
107
def main():
    """Initialise the EIC client, then run the set, exists and get checks."""
    client = init_eic_client()
    # Order matters: the keys are written first, then probed, then read back.
    for check in (test_set, test_exists, test_get):
        check(client)


if __name__ == "__main__":
    main()
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
  import os
3
- import threading
4
3
  from abc import ABC, abstractmethod
5
4
  from typing import List
6
5
 
@@ -12,7 +12,12 @@ from typing import Any, List, Optional, Tuple
12
12
 
13
13
  import torch
14
14
 
15
- from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
15
+ from sglang.srt.mem_cache.hicache_storage import (
16
+ HiCacheStorage,
17
+ HiCacheStorageConfig,
18
+ HiCacheStorageExtraInfo,
19
+ )
20
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
16
21
  from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
17
22
  from sglang.srt.metrics.collector import StorageMetrics
18
23
 
@@ -178,11 +183,14 @@ class HiCacheHF3FS(HiCacheStorage):
178
183
  self.skip_backup = True
179
184
  self.rank = 0
180
185
 
186
+ self.is_zero_copy = False
187
+
181
188
  logger.info(
182
189
  f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
183
190
  f"file_path={self.file_path}, "
184
191
  f"file_size={self.file_size / (2 ** 30):.2f} GB, "
185
- f"num_pages={self.num_pages}"
192
+ f"num_pages={self.num_pages}, "
193
+ f"is_mla_model={self.is_mla_model}"
186
194
  )
187
195
 
188
196
  self.ac = AtomicCounter(self.numjobs)
@@ -323,25 +331,12 @@ class HiCacheHF3FS(HiCacheStorage):
323
331
  use_mock_client=use_mock_client,
324
332
  )
325
333
 
326
- def get(
327
- self,
328
- key: str,
329
- target_location: Optional[Any] = None,
330
- target_sizes: Optional[Any] = None,
331
- ) -> torch.Tensor | None:
332
- return self.batch_get(
333
- [key],
334
- [target_location] if target_location is not None else None,
335
- [target_sizes] if target_sizes is not None else None,
336
- )[0]
337
-
338
334
  @synchronized()
339
- def batch_get(
335
+ def _batch_get(
340
336
  self,
341
337
  keys: List[str],
342
- target_locations: Optional[Any] = None,
343
- target_sizes: Optional[Any] = None,
344
- ) -> List[torch.Tensor | None]:
338
+ values: List[torch.Tensor],
339
+ ) -> List[bool]:
345
340
  page_indices = self.metadata_client.get_page_indices(self.rank, keys)
346
341
 
347
342
  batch_indices, file_offsets = [], []
@@ -350,15 +345,9 @@ class HiCacheHF3FS(HiCacheStorage):
350
345
  batch_indices.append(i)
351
346
  file_offsets.append(page_index * self.bytes_per_page)
352
347
 
353
- if target_locations is not None:
354
- for target_location in target_locations:
355
- assert target_location.is_contiguous()
356
- file_results = target_locations
357
- else:
358
- file_results = [
359
- torch.empty(self.numel, dtype=self.dtype)
360
- for _ in range(len(batch_indices))
361
- ]
348
+ for target_location in values:
349
+ assert target_location.is_contiguous()
350
+ file_results = values
362
351
 
363
352
  start_time = time.perf_counter()
364
353
 
@@ -379,12 +368,10 @@ class HiCacheHF3FS(HiCacheStorage):
379
368
  ionum / (end_time - start_time) * self.gb_per_page
380
369
  )
381
370
 
382
- results = [None] * len(keys)
383
- for batch_index, file_result, read_result in zip(
384
- batch_indices, file_results, read_results
385
- ):
371
+ results = [False] * len(keys)
372
+ for batch_index, read_result in zip(batch_indices, read_results):
386
373
  if read_result == self.bytes_per_page:
387
- results[batch_index] = file_result
374
+ results[batch_index] = True
388
375
  else:
389
376
  logger.error(
390
377
  f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
@@ -392,28 +379,12 @@ class HiCacheHF3FS(HiCacheStorage):
392
379
 
393
380
  return results
394
381
 
395
- def set(
396
- self,
397
- key: str,
398
- value: Optional[Any] = None,
399
- target_location: Optional[Any] = None,
400
- target_sizes: Optional[Any] = None,
401
- ) -> bool:
402
- return self.batch_set(
403
- [key],
404
- [value] if value is not None else None,
405
- [target_location] if target_location is not None else None,
406
- [target_sizes] if target_sizes is not None else None,
407
- )
408
-
409
382
  @synchronized()
410
- def batch_set(
383
+ def _batch_set(
411
384
  self,
412
385
  keys: List[str],
413
386
  values: Optional[Any] = None,
414
- target_locations: Optional[Any] = None,
415
- target_sizes: Optional[Any] = None,
416
- ) -> bool:
387
+ ) -> List[bool]:
417
388
  # In MLA backend, only one rank needs to backup the KV cache
418
389
  if self.skip_backup:
419
390
  return True
@@ -474,7 +445,7 @@ class HiCacheHF3FS(HiCacheStorage):
474
445
  self.rank, written_keys_to_confirm, pages_to_release
475
446
  )
476
447
 
477
- return all(results)
448
+ return results
478
449
 
479
450
  def delete(self, key: str) -> None:
480
451
  self.metadata_client.delete_keys(self.rank, [key])
@@ -483,22 +454,28 @@ class HiCacheHF3FS(HiCacheStorage):
483
454
  result = self.metadata_client.exists(self.rank, [key])
484
455
  return result[0] if result else False
485
456
 
486
- def batch_exists(self, keys: List[str]) -> int:
457
+ def batch_exists(
458
+ self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
459
+ ) -> int:
460
+ factor = 1
461
+ if self.is_zero_copy and not self.is_mla_model:
462
+ keys = self._get_mha_zero_copy_keys(keys)
463
+ factor = 2
464
+
487
465
  results = self.metadata_client.exists(self.rank, keys)
488
- for i in range(len(keys)):
489
- if not results[i]:
490
- return i
491
466
 
492
- return len(keys)
467
+ i = 0
468
+ while i < len(keys) and results[i]:
469
+ i += 1
493
470
 
494
- def clear(self) -> bool:
471
+ return i // factor
472
+
473
+ def clear(self) -> None:
495
474
  try:
496
475
  self.metadata_client.clear(self.rank)
497
476
  logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
498
- return True
499
477
  except Exception as e:
500
478
  logger.error(f"Failed to clear HiCacheHF3FS: {e}")
501
- return False
502
479
 
503
480
  def close(self) -> None:
504
481
  try:
@@ -521,3 +498,147 @@ class HiCacheHF3FS(HiCacheStorage):
521
498
  self.prefetch_bandwidth.clear()
522
499
  self.backup_bandwidth.clear()
523
500
  return storage_metrics
501
+
502
+ def register_mem_pool_host(self, mem_pool_host: HostKVCache):
503
+ super().register_mem_pool_host(mem_pool_host)
504
+ self.is_zero_copy = self.mem_pool_host.layout in [
505
+ "page_first",
506
+ "page_first_direct",
507
+ ]
508
+
509
+ logger.info(f"{self.is_zero_copy=}, layout={self.mem_pool_host.layout}")
510
+
511
+ def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
512
+ _keys = []
513
+ for k in keys:
514
+ _keys.append(f"{k}-k")
515
+ _keys.append(f"{k}-v")
516
+ return _keys
517
+
518
+ def _get_mha_zero_copy_values(
519
+ self, values: List[torch.Tensor]
520
+ ) -> List[torch.Tensor]:
521
+ _values = []
522
+ for value in values:
523
+ _values.append(value[0])
524
+ _values.append(value[1])
525
+ return _values
526
+
527
+ def _batch_get_preprocess(self, keys, host_indices):
528
+ page_num = len(host_indices) // self.mem_pool_host.page_size
529
+ # host_indices to kv_buffer
530
+ flat = not self.is_zero_copy
531
+ values = (
532
+ [
533
+ self.mem_pool_host.get_data_page(
534
+ host_indices[i * self.mem_pool_host.page_size], flat=flat
535
+ )
536
+ for i in range(page_num)
537
+ ]
538
+ if self.is_zero_copy
539
+ else [
540
+ self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
541
+ ]
542
+ )
543
+
544
+ if self.is_zero_copy and not self.is_mla_model:
545
+ keys = self._get_mha_zero_copy_keys(keys)
546
+ values = self._get_mha_zero_copy_values(values)
547
+
548
+ return keys, values
549
+
550
+ def _batch_get_postprocess(self, host_indices, values, results):
551
+ page_num = len(host_indices) // self.mem_pool_host.page_size
552
+
553
+ if self.is_zero_copy:
554
+ if not self.is_mla_model:
555
+ results = [
556
+ (results[2 * i] and results[2 * i + 1]) for i in range(page_num)
557
+ ]
558
+ results = results[:page_num]
559
+ return results
560
+
561
+ for i in range(page_num):
562
+ if not results[i]:
563
+ break
564
+ self.mem_pool_host.set_from_flat_data_page(
565
+ host_indices[i * self.mem_pool_host.page_size], values[i]
566
+ )
567
+
568
+ return results
569
+
570
+ def batch_get_v1(
571
+ self,
572
+ keys: List[str],
573
+ host_indices: torch.Tensor,
574
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
575
+ ) -> List[bool]:
576
+ keys, values = self._batch_get_preprocess(keys, host_indices)
577
+ results = self._batch_get(keys, values)
578
+ return self._batch_get_postprocess(host_indices, values, results)
579
+
580
+ def _batch_set_preprocess(self, keys, host_indices):
581
+ page_num = len(host_indices) // self.mem_pool_host.page_size
582
+ # host_indices to kv_buffer
583
+ flat = not self.is_zero_copy
584
+ values = [
585
+ self.mem_pool_host.get_data_page(
586
+ host_indices[i * self.mem_pool_host.page_size], flat=flat
587
+ )
588
+ for i in range(page_num)
589
+ ]
590
+
591
+ if self.is_zero_copy and not self.is_mla_model:
592
+ keys = self._get_mha_zero_copy_keys(keys)
593
+ values = self._get_mha_zero_copy_values(values)
594
+
595
+ return keys, values
596
+
597
+ def batch_set_v1(
598
+ self,
599
+ keys: List[str],
600
+ host_indices: torch.Tensor,
601
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
602
+ ) -> List[bool]:
603
+ len_keys = len(keys)
604
+ keys, values = self._batch_set_preprocess(keys, host_indices)
605
+ results = self._batch_set(keys, values)
606
+ return results
607
+
608
+ # Deprecated
609
+ def get(
610
+ self,
611
+ key: str,
612
+ target_location: Optional[Any] = None,
613
+ target_sizes: Optional[Any] = None,
614
+ ) -> torch.Tensor | None:
615
+ pass
616
+
617
+ # Deprecated
618
+ def batch_get(
619
+ self,
620
+ keys: List[str],
621
+ target_locations: Optional[Any] = None,
622
+ target_sizes: Optional[Any] = None,
623
+ ) -> List[torch.Tensor | None] | int:
624
+ pass
625
+
626
+ # Deprecated
627
+ def set(
628
+ self,
629
+ key: str,
630
+ value: Optional[Any] = None,
631
+ target_location: Optional[Any] = None,
632
+ target_sizes: Optional[Any] = None,
633
+ ) -> bool:
634
+ pass
635
+
636
+ # Deprecated
637
+ def batch_set(
638
+ self,
639
+ keys: List[str],
640
+ values: Optional[Any] = None,
641
+ target_locations: Optional[Any] = None,
642
+ target_sizes: Optional[Any] = None,
643
+ ) -> bool:
644
+ pass
@@ -2,14 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import threading
5
- from typing import TYPE_CHECKING, List, Optional
5
+ from typing import TYPE_CHECKING, Optional
6
6
 
7
7
  import torch
8
8
 
9
9
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
10
10
  from sglang.srt.mem_cache.base_prefix_cache import MatchResult
11
11
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
12
- from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
12
+ from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
13
13
 
14
14
  try:
15
15
  from lmcache.integration.sglang.sglang_adapter import (
@@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache):
78
78
  tp_size: int = 1,
79
79
  rank: int = 0,
80
80
  tp_group: Optional[torch.distributed.ProcessGroup] = None,
81
+ eviction_policy: str = "lru",
81
82
  ):
82
83
  super().__init__(
83
84
  req_to_token_pool=req_to_token_pool,
@@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache):
85
86
  page_size=page_size,
86
87
  disable=disable,
87
88
  enable_kv_cache_events=enable_kv_cache_events,
89
+ eviction_policy=eviction_policy,
88
90
  )
89
91
 
90
92
  kvcache = self.token_to_kv_pool_allocator.get_kvcache()
@@ -129,7 +131,7 @@ class LMCRadixCache(RadixCache):
129
131
  with self._node_lock:
130
132
  self._in_flight_nodes.clear()
131
133
 
132
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override]
134
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override]
133
135
  """Match cached prefix; if there's a tail miss, prefetch from LMCache.
134
136
 
135
137
  Reuses the base matching logic to obtain (value, last_node). If there
@@ -176,7 +178,7 @@ class LMCRadixCache(RadixCache):
176
178
  with torch.cuda.stream(self.load_stream):
177
179
  num_retrieved = self.lmcache_connector.start_load_kv(
178
180
  LoadMetadata(
179
- token_ids=key, # full page-aligned key
181
+ token_ids=key.token_ids, # full page-aligned key
180
182
  slot_mapping=slot_mapping,
181
183
  offset=value.numel() - prefix_pad, # LMCache offset convention
182
184
  )
@@ -215,17 +217,19 @@ class LMCRadixCache(RadixCache):
215
217
 
216
218
  return base_res
217
219
 
218
- def cache_finished_req(self, req: "Req") -> None: # type: ignore[override]
220
+ def cache_finished_req(self, req: "Req", is_insert: bool = True) -> None: # type: ignore[override]
219
221
  """On request completion, insert device KV into radix and store to LMCache."""
220
222
 
221
- super().cache_finished_req(req)
223
+ super().cache_finished_req(req, is_insert=is_insert)
224
+ if not is_insert:
225
+ return
222
226
 
223
227
  token_ids = (req.origin_input_ids + req.output_ids)[:-1]
224
228
  kv_indices = self.req_to_token_pool.req_to_token[
225
229
  req.req_pool_idx, : len(token_ids)
226
230
  ]
227
231
 
228
- _, new_last_node, _, _ = self.match_prefix(token_ids)
232
+ _, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
229
233
  assert new_last_node is not None
230
234
 
231
235
  self.inc_lock_ref(new_last_node)
@@ -275,6 +279,8 @@ if __name__ == "__main__":
275
279
  rank=0,
276
280
  tp_group=None,
277
281
  )
278
- cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
279
- cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
282
+ cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
283
+ cache.insert(
284
+ RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
285
+ )
280
286
  cache.pretty_print()