sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,777 @@
1
import json
import logging
import math
import os
import time
from typing import Any, List, Optional, Tuple

import eic
import torch
import yaml

from sglang.srt.mem_cache.hicache_storage import (
    HiCacheStorage,
    HiCacheStorageConfig,
    HiCacheStorageExtraInfo,
)
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ TensorPoolSize = 2048
22
+
23
+ REMOTE_EIC_YAML_ENV_VAR = "REMOTE_EIC_YAML"
24
+
25
+ # gpu direct rdma for kv set
26
+ G_EnableKVSetGPUDirect = False
27
+
28
+ # gpu direct rdma for kv get
29
+ G_EnableKVGetGPUDirect = False
30
+
31
+ # gpu nic affinity
32
+ G_EnableGPUNicAffinity = False
33
+
34
+ # default H20 gpu nic affinity
35
+ GPUNicAffinity = {
36
+ "cuda:0": "eth1",
37
+ "cuda:1": "eth1",
38
+ "cuda:2": "eth2",
39
+ "cuda:3": "eth2",
40
+ "cuda:4": "eth3",
41
+ "cuda:5": "eth3",
42
+ "cuda:6": "eth4",
43
+ "cuda:7": "eth4",
44
+ }
45
+
46
+ # default H20 cpu nic affinity
47
+ CPUNicAffinity = {
48
+ "cuda:0": "cpu",
49
+ "cuda:1": "cpu",
50
+ "cuda:2": "cpu",
51
+ "cuda:3": "cpu",
52
+ "cuda:4": "cpu",
53
+ "cuda:5": "cpu",
54
+ "cuda:6": "cpu",
55
+ "cuda:7": "cpu",
56
+ }
57
+
58
+
59
+ def get_eic_config_file_path():
60
+ if os.environ.get(REMOTE_EIC_YAML_ENV_VAR) is not None:
61
+ logger.info(f"eic init with env var {REMOTE_EIC_YAML_ENV_VAR}")
62
+ config_file = os.environ.get(REMOTE_EIC_YAML_ENV_VAR)
63
+ else:
64
+ config_file = "/sgl-workspace/config/remote-eic.yaml"
65
+ logger.info(f"eic init with default config, config_file {config_file}")
66
+ return config_file
67
+
68
+
69
class FlexibleKVCacheMemoryPool:
    """Pre-registered pool of fixed-shape KV-cache tensors for EIC transfers.

    One large backing tensor (``kvcache_mempool``) is allocated up front and
    registered with the EIC connection once, so its per-slot slices can be
    used for transfers without per-request registration.  Slots are handed
    out by index and recycled via their ``data_ptr``.
    """

    def __init__(self, conn, kvcache_shape, kvcache_dtype, device):
        """Allocate and register the backing memory pool.

        Args:
            conn: EIC connection used to register the pool memory.
            kvcache_shape: shape of a single pooled KV-cache tensor.
            kvcache_dtype: dtype of the pooled tensors.
            device: torch device string ("cpu..." or "cuda:N") for the pool.
        """
        self.connection = conn

        if device.startswith("cpu") and G_EnableGPUNicAffinity:
            # With NIC affinity enabled, place the host pool on the device
            # mapped to the current GPU (see CPUNicAffinity).
            gpu_id = torch.cuda.current_device()
            self.device = CPUNicAffinity["cuda:" + str(gpu_id)]
            # current memory pool size is 5 times of CPU TensorPoolSize
            mempool_size = TensorPoolSize * 5
        else:
            self.device = device
            mempool_size = TensorPoolSize

        self.kvcache_shape = kvcache_shape
        self.kvcache_dtype = kvcache_dtype

        # Element count of a single pooled tensor.
        self.kv_cache_numel = math.prod(self.kvcache_shape)

        # CPU pools are pinned so they can be used for device/RDMA transfers.
        if self.device.startswith("cpu"):
            self.kvcache_mempool = torch.zeros(
                (mempool_size,) + kvcache_shape,
                dtype=kvcache_dtype,
                device=self.device,
                pin_memory=True,
            )
        else:
            self.kvcache_mempool = torch.zeros(
                (mempool_size,) + kvcache_shape, dtype=kvcache_dtype, device=self.device
            )

        # Free slot indices, and reverse lookup from slice data_ptr to slot.
        self.free_data_addr = set(range(mempool_size))
        self.data_ptr_to_index = {
            self.kvcache_mempool[i].data_ptr(): i for i in range(mempool_size)
        }

        # Register the whole pool with EIC in a single call.
        # NOTE(review): type is MEMORY_CUDA and cuda_id is 0 even for CPU
        # pools / non-zero CUDA devices — presumably tolerated by EIC;
        # confirm against the EIC API.
        meminfo = eic.MemoryInfo()
        meminfo.type = eic.MemoryType.MEMORY_CUDA
        meminfo.cuda_id = 0
        vals = eic.IOBuffers()
        vals.append(
            self.kvcache_mempool.data_ptr(),
            self.kvcache_mempool.numel() * self.kvcache_mempool.element_size(),
            True,
        )
        self.connection.register_memory(vals, meminfo)
        logger.info(
            f"allocate memory pool, size {self.kvcache_mempool.numel() * self.kvcache_mempool.element_size()}, device {self.device}"
        )

    def try_allocate_kv_cache(self, shape, dtype, count=1):
        """Hand out ``count`` pooled tensors matching ``shape``/``dtype``.

        Returns a list of tensor slices, or None when the pool has fewer
        than ``count`` free slots or the request does not match the pooled
        element count / dtype.
        """
        if len(self.free_data_addr) < count:
            return None

        # Any shape with the same total element count is accepted; callers
        # may view/reshape the returned slice.
        if math.prod(shape) != self.kv_cache_numel or dtype != self.kvcache_dtype:
            logger.error(
                f"allocate from mempool failed, self.kvcache_shape {self.kvcache_shape}, dtype {self.kvcache_dtype}, require shape {shape}, dtype {dtype}"
            )
            return None

        ret = []
        for _ in range(count):
            free_index = self.free_data_addr.pop()
            ret.append(self.kvcache_mempool[free_index])
        return ret

    def free_to_mempool(self, data_ptr):
        """Return the slot identified by ``data_ptr`` to the free list.

        Unknown pointers are logged and ignored (best-effort).
        """
        if data_ptr not in self.data_ptr_to_index:
            logger.error(
                f"free_to_mempool failed, data_ptr {data_ptr} not in allocated_data_addr"
            )
            return
        self.free_data_addr.add(self.data_ptr_to_index[data_ptr])

    def check_data_ptr_allocated(self, data_ptr):
        """Whether ``data_ptr`` belongs to a slot of this pool."""
        return data_ptr in self.data_ptr_to_index

    def left_count(self):
        """Number of free slots remaining in the pool."""
        return len(self.free_data_addr)
154
+
155
+
156
+ class EICStorage(HiCacheStorage):
157
+ def __init__(
158
+ self, hicache_config: HiCacheStorageConfig, memory_pool_host: HostKVCache
159
+ ):
160
+ global G_EnableKVSetGPUDirect, G_EnableKVGetGPUDirect
161
+ global GPUNicAffinity, CPUNicAffinity, G_EnableGPUNicAffinity
162
+
163
+ config_file = get_eic_config_file_path()
164
+ if os.path.exists(config_file) is False:
165
+ logger.error(f"config file {config_file} not exists")
166
+ raise RuntimeError(f"eic config file {config_file} not exists")
167
+
168
+ with open(config_file, "r") as fin:
169
+ config = yaml.safe_load(fin)
170
+
171
+ remote_url = config.get("remote_url", None)
172
+ if remote_url is None:
173
+ AssertionError("remote_url is None")
174
+
175
+ endpoint = remote_url[len("eic://") :]
176
+
177
+ logger.info(f"eic remote_url:" + remote_url + " endpoint: " + endpoint)
178
+
179
+ eic_instance_id = config.get("eic_instance_id", None)
180
+ logger.info(f"eic instance_id: {eic_instance_id}")
181
+
182
+ eic_thread_num = config.get("eic_thread_num", 1)
183
+ logger.info(f"eic thread_num: {eic_thread_num}")
184
+
185
+ eic_log_dir = config.get("eic_log_dir", None)
186
+ logger.info(f"eic log_dir: {eic_log_dir}")
187
+
188
+ eic_log_level = config.get("eic_log_level", 2)
189
+ logger.info(f"eic log_level: {eic_log_level}")
190
+
191
+ eic_trans_type = config.get("eic_trans_type", 3)
192
+ logger.info(f"eic trans_type: {eic_trans_type}")
193
+
194
+ eic_flag_file = config.get("eic_flag_file", None)
195
+ logger.info(f"eic flag_file: {eic_flag_file}")
196
+
197
+ # GDR now is not used
198
+ G_EnableKVSetGPUDirect = (
199
+ config.get("enable_kvset_gpu_direct", False) and torch.cuda.is_available()
200
+ )
201
+ logger.debug(f"eic enable_kvset_gpu_direct: {G_EnableKVSetGPUDirect}")
202
+
203
+ G_EnableKVGetGPUDirect = (
204
+ config.get("enable_kvget_gpu_direct", False) and torch.cuda.is_available()
205
+ )
206
+ logger.debug(f"eic enable_kvget_gpu_direct: {G_EnableKVGetGPUDirect}")
207
+
208
+ self.model_name = hicache_config.model_name
209
+
210
+ # rdma
211
+ enable_kv_set_direct = config.get("enable_kvset_direct", True)
212
+ logger.info(f"eic enable_kv_set_direct: {enable_kv_set_direct}")
213
+ self.enable_kv_set_direct = enable_kv_set_direct
214
+
215
+ enable_kv_get_direct = config.get("enable_kvget_direct", True)
216
+ logger.info(f"eic enable_kv_get_direct: {enable_kv_get_direct}")
217
+ self.enable_kv_get_direct = enable_kv_get_direct
218
+
219
+ # gpu nic affinity
220
+ G_EnableGPUNicAffinity = config.get("enable_gpu_nic_affinity", False)
221
+ logger.info(f"eic enable_gpu_nic_affinity: {G_EnableGPUNicAffinity}")
222
+ self.enable_gpu_nic_affinity = G_EnableGPUNicAffinity
223
+
224
+ if G_EnableGPUNicAffinity:
225
+ if "gpu_nic_affinity_config" in config:
226
+ GPUNicAffinity = json.loads(config["gpu_nic_affinity_config"])
227
+ if "cpu_nic_affinity_config" in config:
228
+ CPUNicAffinity = json.loads(config["cpu_nic_affinity_config"])
229
+ logger.info(f"eic gpu nic affinity {GPUNicAffinity}")
230
+ logger.info(f"eic cpu nic affinity {CPUNicAffinity}")
231
+
232
+ eic_namespace = config.get("eic_namespace", "")
233
+ logger.info(f"eic namespace: {eic_namespace}")
234
+ self.eic_namespace = eic_namespace
235
+
236
+ if not os.path.exists(eic_log_dir) and not os.path.isdir(eic_log_dir):
237
+ os.makedirs(eic_log_dir, exist_ok=True)
238
+
239
+ self.connection = eic.Client()
240
+ init_option = eic.InitOption()
241
+ init_option.log_dir = eic_log_dir
242
+ init_option.log_level = eic.LogLevel(eic_log_level)
243
+ init_option.transport_type = eic.TransportType(eic_trans_type)
244
+ init_option.flag_file = eic_flag_file
245
+
246
+ if G_EnableGPUNicAffinity:
247
+ gpu_id = torch.cuda.current_device()
248
+ init_option.multi_net_local_interface_names = GPUNicAffinity[
249
+ "cuda:" + str(gpu_id)
250
+ ]
251
+ logger.info(
252
+ f"gpu {gpu_id} set gpu nic affinity to {init_option.multi_net_local_interface_names}"
253
+ )
254
+
255
+ ret = self.connection.init(eic_instance_id, endpoint, init_option)
256
+ if ret != 0:
257
+ logger.error(f"fail to init eic client, ret: {ret}")
258
+ raise RuntimeError("EIC Client Init Failed.")
259
+ self.warmup()
260
+
261
+ self.memory_pool_host = memory_pool_host
262
+ self.host_kvcache_layout = self.memory_pool_host.layout
263
+ self.trans_type = eic.TransportType(eic_trans_type)
264
+ self.kv_cache_dtype = self.memory_pool_host.dtype
265
+ self.is_mla_model = hicache_config.is_mla_model
266
+ self.rank = hicache_config.tp_rank
267
+ self.world_size = hicache_config.tp_size
268
+ self.page_size = self.memory_pool_host.page_size
269
+ self.use_zero_copy = self.memory_pool_host.layout == "page_first"
270
+ if not self.use_zero_copy:
271
+ self.kv_cache_shape = self.memory_pool_host.get_data_page(
272
+ 0, flat=True
273
+ ).shape
274
+ if self.enable_kv_set_direct:
275
+ self.kv_cache_write_mem_pool = FlexibleKVCacheMemoryPool(
276
+ self.connection, self.kv_cache_shape, self.kv_cache_dtype, "cpu"
277
+ )
278
+ if self.enable_kv_get_direct:
279
+ self.kv_cache_get_mem_pool = FlexibleKVCacheMemoryPool(
280
+ self.connection, self.kv_cache_shape, self.kv_cache_dtype, "cpu"
281
+ )
282
+ self._init_eic_prefix()
283
+
284
+ def warmup(self):
285
+ logger.info("begin warm up eic client")
286
+ start_time = time.perf_counter()
287
+ num_warmup = 1024
288
+ preheat_keys = ["warmup_key_" + str(i) for i in range(num_warmup)]
289
+ batch_size = 32
290
+ for i in range(0, num_warmup, batch_size):
291
+ keys_vec = eic.StringVector()
292
+ for key in preheat_keys[i : i + batch_size]:
293
+ keys_vec.append(key)
294
+ exist_option = eic.ExistOption()
295
+ _, _ = self.connection.mexist(keys_vec, exist_option)
296
+ logger.info(
297
+ f"finish eic client warm up, warm up cost {time.perf_counter() - start_time:.2f} seconds"
298
+ )
299
+
300
    def register_mem_pool_host(self, memory_pool_host: HostKVCache) -> None:
        """Register the host memory pool's entire KV buffer with eic so the
        transport can DMA into/out of it.

        NOTE(review): meminfo is tagged MEMORY_CUDA with cuda_id 0 even though
        this buffer is host-side — presumably the transport treats the tag
        uniformly; confirm against the eic API.
        """
        # no need judge meminfo type, cuda_id, etc.
        meminfo = eic.MemoryInfo()
        meminfo.type = eic.MemoryType.MEMORY_CUDA
        meminfo.cuda_id = 0
        vals = eic.IOBuffers()
        buffer = memory_pool_host.kv_buffer
        # (ptr, byte_length, registered_flag) triple covering the whole buffer.
        vals.append(
            buffer.data_ptr(),
            buffer.numel() * buffer.element_size(),
            True,
        )
        self.connection.register_memory(vals, meminfo)
313
+
314
    def _init_eic_prefix(self):
        """Build the prefix that namespaces every eic key for this model.

        NOTE(review): the MLA prefix omits rank/world_size — presumably MLA KV
        pages are identical across TP ranks; confirm.  The "_mla_att_" vs
        "_mha_attn_" spelling difference is part of the live key format and
        must not be normalized.
        """
        if self.is_mla_model:
            self.eic_prefix = (
                f"{self.model_name}_mla_att_{self.host_kvcache_layout}@sglang"
            )
        else:
            self.eic_prefix = f"{self.model_name}_mha_attn_{self.host_kvcache_layout}_{self.rank}_{self.world_size}_@sglang"
321
+
322
    def _get_eic_key(self, keys: List[str]) -> List[str]:
        """Prefix each cache key with the model/layout-specific eic prefix."""
        # Annotation fixed: this returns a list of strings, not a single str.
        return [f"{self.eic_prefix}_{key}" for key in keys]
324
+
325
    def set(
        self,
        key: str,
        value: Optional[Any] = None,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """Store a single key via the batch helpers.

        NOTE(review): despite the bool annotation this returns the one-element
        list produced by the batch helpers — confirm callers only truth-test
        the result.  The zero-copy path reads from `target_location`, the
        generic path from `value`.
        """
        # now is not used
        if self.use_zero_copy:
            return self.zero_copy_batch_set([key], [target_location])
        else:
            return self.generic_batch_set([key], [value])
337
+
338
    # target_locations and target_sizes are not used for now
    def batch_set(
        self,
        keys: List[str],
        values: Optional[Any] = None,
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """Store `values` under `keys`, choosing zero-copy or staged transfer
        based on the host cache layout.

        NOTE(review): returns True for an empty batch but a per-key list of
        bools otherwise — inconsistent with the bool annotation; confirm
        callers only truth-test or index the result.
        """
        if len(keys) == 0:
            return True
        if self.use_zero_copy:
            return self.zero_copy_batch_set(keys, values)
        else:
            return self.generic_batch_set(keys, values)
352
+
353
    def get(
        self,
        key,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Fetch a single key into `target_location` via the batch helpers.

        NOTE(review): despite the annotation this returns the one-element
        success list from the batch helpers, not a tensor — confirm callers.
        """
        # now is not used
        if self.use_zero_copy:
            return self.zero_copy_batch_get([key], [target_location])
        else:
            return self.generic_batch_get([key], [target_location])
364
+
365
+ # use for v1 interface, and shound not be called directly
366
+ def batch_get(
367
+ self,
368
+ keys: List[str],
369
+ target_locations: Optional[Any] = None,
370
+ target_sizes: Optional[Any] = None,
371
+ ) -> List[torch.Tensor | None]:
372
+ assert len(keys) == len(target_locations)
373
+ if len(keys) == 0:
374
+ return None
375
+ if self.use_zero_copy:
376
+ return self.zero_copy_batch_get(keys, target_locations)
377
+ else:
378
+ return self.generic_batch_get(keys, target_locations)
379
+
380
+ def _batch_exists_impl(self, keys) -> List[bool]:
381
+ if len(keys) == 0:
382
+ return 0
383
+ eic_keys = self._get_eic_key(keys)
384
+ logger.debug(f"eic exists {len(keys)}")
385
+ result = []
386
+ exist_bs = 1024
387
+ for i in range(0, len(eic_keys), exist_bs):
388
+ batch_keys = eic_keys[i : i + exist_bs]
389
+ keys_vec = eic.StringVector()
390
+ for key in batch_keys:
391
+ keys_vec.append(key)
392
+ exist_option = eic.ExistOption()
393
+ exist_option.ns = self.eic_namespace
394
+ status_code, exist_outcome = self.connection.mexist(keys_vec, exist_option)
395
+ if status_code != eic.StatusCode.SUCCESS:
396
+ logger.error(
397
+ f"eic exists {len(keys)} failed, status_code {status_code}"
398
+ )
399
+ result.extend([False] * len(batch_keys))
400
+ for err_code in exist_outcome.status_codes:
401
+ result.append(err_code == eic.StatusCode.SUCCESS)
402
+ return result
403
+
404
+ def exists(self, key) -> bool:
405
+ exist_num = self.batch_exists([key])
406
+ return exist_num == 1
407
+
408
+ def batch_exists(
409
+ self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None
410
+ ) -> int:
411
+ if len(keys) == 0:
412
+ return 0
413
+ if self.use_zero_copy and not self.is_mla_model:
414
+ keys = self._get_mha_zero_copy_keys(keys)
415
+ exist_mask = self._batch_exists_impl(keys)
416
+ prefix_success = 0
417
+ for exist in exist_mask:
418
+ if exist:
419
+ prefix_success += 1
420
+ else:
421
+ break
422
+ if not self.is_mla_model and self.use_zero_copy:
423
+ prefix_success = prefix_success // 2
424
+ return prefix_success
425
+
426
+ def delete(self, key) -> None:
427
+ eic_keys = self._get_eic_key([key])
428
+ keys_vec = eic.StringVector()
429
+ for eic_key in eic_keys:
430
+ keys_vec.append(eic_key)
431
+ del_option = eic.DelOption()
432
+ self.connection.mdel(keys_vec, del_option)
433
+
434
    def clear(self) -> None:
        """Intentional no-op: EICStorage does not support wholesale clearing."""
        return
436
+
437
+ # Not used for now
438
+ def _filter_kv_cache(self, total_len) -> Tuple[int, int]:
439
+ mean_len = total_len // self.world_size
440
+ remainder = total_len % self.world_size
441
+ tp_keys_len = mean_len + (1 if self.rank < remainder else 0)
442
+ start = self.rank * mean_len + min(self.rank, remainder)
443
+ end = start + tp_keys_len
444
+ logger.debug(f"start: {start}, end: {end}, tp_keys_len: {tp_keys_len}")
445
+ return start, end
446
+
447
    def zero_copy_batch_set(
        self, keys: List[str], values: List[torch.Tensor]
    ) -> List[bool]:
        """mset each value tensor straight from its own buffer (no staging).

        Returns a per-key success list; an mset-level failure marks every key
        failed.  Annotation fixed from `bool` — the non-empty paths return
        lists.  NOTE(review): the empty-batch path still returns True rather
        than []; confirm callers only truth-test the result.
        """
        logger.debug(f"eic zero copy set {len(keys)} keys")
        if len(keys) == 0:
            return True
        eic_keys = self._get_eic_key(keys)
        keys_vec = eic.StringVector()
        vals_vec = eic.IOBuffers()
        # set data key & value
        for i, key in enumerate(eic_keys):
            # set data key & value
            keys_vec.append(key)
            # (ptr, nbytes, registered=True) — presumes these buffers were
            # registered via register_mem_pool_host; confirm it is called.
            vals_vec.append(
                values[i].data_ptr(),
                values[i].element_size() * values[i].numel(),
                True,
            )
        # set options
        set_option = eic.SetOption()
        set_option.ns = self.eic_namespace
        set_option.ttl_second = -1
        status_code, set_outcome = self.connection.mset(keys_vec, vals_vec, set_option)
        if status_code != eic.StatusCode.SUCCESS:
            logger.error(f"eic mset {len(keys)} failed, status_code {status_code}")
            return [False] * len(keys)
        else:
            logger.debug(f"eic zero copy mset {len(keys)} success")
        return [True] * len(keys)
474
+
475
    def zero_copy_batch_get(
        self, keys: List[str], values: List[torch.Tensor]
    ) -> List[bool]:
        """mget `keys` directly into the tensors in `values` (no staging).

        Returns one success flag per key.  A PARTIAL_FAILED status flips only
        the failed keys; any other non-success status fails the whole batch.
        """
        logger.debug(f"eic zero copy get {len(keys)} keys")
        # Get Data: generate data keys and vals
        get_data_start_time = time.perf_counter()
        eic_keys = self._get_eic_key(keys)
        data_keys = eic.StringVector()
        data_vals = eic.IOBuffers()
        success_mask = [True] * len(keys)
        count = len(keys)
        for i, key in enumerate(eic_keys):
            data_keys.append(key)
            # (ptr, nbytes, registered=True) — presumes the destination pages
            # were registered via register_mem_pool_host; confirm.
            data_vals.append(
                values[i].data_ptr(),
                values[i].element_size() * values[i].numel(),
                True,
            )

        # Get data: recv data buffer tensor
        get_option = eic.GetOption()
        get_option.ns = self.eic_namespace
        status_code, data_vals, get_outcome = self.connection.mget(
            data_keys, get_option, data_vals
        )

        if status_code != eic.StatusCode.SUCCESS:
            if status_code == eic.StatusCode.PARTIAL_FAILED:
                for i, err_code in enumerate(get_outcome.status_codes):
                    success = err_code == eic.StatusCode.SUCCESS
                    if success:
                        logger.debug(f"eic get data {eic_keys[i]} success")
                    else:
                        logger.error(
                            f"eic get data {eic_keys[i]} failed, err_code {err_code}"
                        )
                        success_mask[i] = False
            else:
                logger.error(
                    f"eic mget {len(eic_keys)} keys failed, status_code {status_code}"
                )
                success_mask = [False] * len(keys)
            return success_mask

        get_data_end_time = time.perf_counter()
        get_data_execution_time = (get_data_end_time - get_data_start_time) * 1e6
        # NOTE(review): f-string mixed with a lazy %-arg; logging's msg % args
        # step still fills in %.2f, but the two schemes should not be mixed.
        logger.debug(f"eic get {count} keys data cost %.2f us", get_data_execution_time)
        return success_mask
523
+
524
    def generic_batch_set(
        self,
        keys: List[str],
        values: List[torch.Tensor],
    ) -> List[bool]:
        """Stage `values` into the registered write pool (when enabled and
        available) and mset them to eic.

        Returns one bool per key.  NOTE(review): only status_codes[0] is
        inspected for per-key errors, so every key reports the first key's
        outcome; a total mset failure with an empty outcome would raise
        IndexError — confirm this is the intended all-or-nothing contract.
        NOTE(review): the empty-batch path returns True, not [].
        """
        assert len(keys) == len(values)
        logger.debug(f"eic generic set {len(keys)} keys")
        if len(keys) == 0:
            return True
        eic_keys = self._get_eic_key(keys)
        keys_vec = eic.StringVector()
        vals_vec = eic.IOBuffers()
        count = len(keys)
        registered = False
        items = []
        if self.enable_kv_set_direct:
            values_data_ptrs = []
            # Try to stage through pre-registered pool tensors (RDMA fast path).
            items = self.kv_cache_write_mem_pool.try_allocate_kv_cache(
                self.kv_cache_shape, self.kv_cache_dtype, count
            )
            if items is None:
                # Pool exhausted: send from source buffers, unregistered.
                logger.warning("can not allocate tensor from pool")
                for i, value in enumerate(values):
                    values_data_ptrs.append(
                        (value.data_ptr(), value.element_size() * value.numel(), False)
                    )
            else:
                objs = items
                registered = True
                for i, key in enumerate(eic_keys):
                    temp = objs[i].reshape(values[i].shape).contiguous()
                    temp.copy_(values[i])
                    # contiguous() may have copied out of the pool slot, in
                    # which case the buffer is no longer registered memory.
                    if temp.data_ptr() != objs[i].data_ptr():
                        registered = False
                        temp = temp.cpu()
                    values_data_ptrs.append(
                        (
                            temp.data_ptr(),
                            temp.element_size() * temp.numel(),
                            registered,
                        )
                    )

            for i, key in enumerate(eic_keys):
                keys_vec.append(key)
                data_ptr, data_size, registered = values_data_ptrs[i]
                vals_vec.append(data_ptr, data_size, registered)
        else:
            # use tensor direct
            for i, key in enumerate(eic_keys):
                keys_vec.append(key)
                vals_vec.append(
                    values[i].data_ptr(),
                    values[i].element_size() * values[i].numel(),
                    False,
                )

        # set options
        set_option = eic.SetOption()
        set_option.ns = self.eic_namespace
        set_option.ttl_second = -1
        status_code, set_outcome = self.connection.mset(keys_vec, vals_vec, set_option)
        if status_code != eic.StatusCode.SUCCESS:
            logger.error(f"eic mset {len(eic_keys)} failed, status_code {status_code}")
        else:
            logger.debug(f"eic mset {len(eic_keys)} success")

        # Return the staging tensors to the pool regardless of mset outcome.
        if self.enable_kv_set_direct and items is not None:
            for item in items:
                self.kv_cache_write_mem_pool.free_to_mempool(item.data_ptr())

        err_code = set_outcome.status_codes[0]
        if err_code != eic.StatusCode.SUCCESS:
            logger.error(f"set data key {len(eic_keys)} failed, err_code {err_code}")
            return [False] * len(keys)

        logger.debug(f"set data key {len(eic_keys)} success")
        return [True] * len(keys)
602
+
603
    def generic_batch_get(
        self, keys: List[str], buffers: List[torch.Tensor]
    ) -> List[bool]:
        """Fetch `keys`, staging through the registered read pool when enabled
        and available, then copying the staged pages into `buffers`.

        Returns one success flag per key.
        """
        # all success or all fail
        logger.debug(f"eic generic get {len(keys)} keys")
        eic_keys = self._get_eic_key(keys)
        get_data_start_time = time.perf_counter()
        data_keys = eic.StringVector()
        data_vals = eic.IOBuffers()
        count = len(eic_keys)
        registered = False
        items = []
        success_mask = [True] * len(keys)
        if self.enable_kv_get_direct:
            items = self.kv_cache_get_mem_pool.try_allocate_kv_cache(
                self.kv_cache_shape, self.kv_cache_dtype, count
            )
            if items is None:
                # Pool exhausted: receive straight into caller buffers,
                # unregistered.
                logger.warning("can not allocate tensor from pool")
                for i, key in enumerate(eic_keys):
                    data_keys.append(key)
                    data_vals.append(
                        buffers[i].data_ptr(),
                        buffers[i].element_size() * buffers[i].numel(),
                        False,
                    )
            else:
                registered = True
                for i, key in enumerate(eic_keys):
                    data_keys.append(key)
                    data_vals.append(
                        items[i].data_ptr(),
                        items[i].element_size() * items[i].numel(),
                        registered,
                    )

        else:
            for i, key in enumerate(eic_keys):
                data_keys.append(key)
                data_vals.append(
                    buffers[i].data_ptr(),
                    buffers[i].element_size() * buffers[i].numel(),
                    False,
                )

        # Get data: recv data buffer tensor
        get_option = eic.GetOption()
        get_option.ns = self.eic_namespace
        status_code, data_vals, get_outcome = self.connection.mget(
            data_keys, get_option, data_vals
        )

        if status_code != eic.StatusCode.SUCCESS:
            if status_code == eic.StatusCode.PARTIAL_FAILED:
                for i, err_code in enumerate(get_outcome.status_codes):
                    success = err_code == eic.StatusCode.SUCCESS
                    if success:
                        logger.debug(f"eic get data {eic_keys[i]} success")
                    else:
                        logger.error(
                            f"eic get data {eic_keys[i]} failed, err_code {err_code}"
                        )
                        success_mask[i] = False
            else:
                logger.error(
                    f"eic mget {len(eic_keys)} keys failed, status_code {status_code}"
                )
                success_mask = [False] * len(keys)

        # Copy staged results into caller buffers and release the pool slots
        # (slots are freed even for failed keys).
        if registered:
            for i, item in enumerate(items):
                if success_mask[i]:
                    buffers[i].copy_(item)
                self.kv_cache_get_mem_pool.free_to_mempool(item.data_ptr())

        get_data_end_time = time.perf_counter()
        get_data_execution_time = (get_data_end_time - get_data_start_time) * 1e6
        # NOTE(review): f-string mixed with a lazy %-arg; logging's msg % args
        # step still fills in %.2f, but the two schemes should not be mixed.
        logger.debug(f"eic get {count} keys data cost %.2f us", get_data_execution_time)
        return success_mask
682
+
683
+ def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
684
+ new_keys = []
685
+ for k in keys:
686
+ new_keys.append(f"{k}_k")
687
+ new_keys.append(f"{k}_v")
688
+ return new_keys
689
+
690
+ def _get_mha_zero_copy_values(
691
+ self, values: List[torch.Tensor]
692
+ ) -> List[torch.Tensor]:
693
+ new_values = []
694
+ for value in values:
695
+ new_values.append(value[0])
696
+ new_values.append(value[1])
697
+ return new_values
698
+
699
    def _batch_get_preprocess(self, keys, host_indices):
        """Build (keys, destination buffers) for a batch get.

        Zero-copy layout targets the host pool's own pages; otherwise each
        page is fetched into a scratch ("dummy") flat page that
        `_batch_get_postprocess` later copies into the pool.
        """
        page_num = len(host_indices) // self.page_size
        # use memory pool directly or dummy page
        values = (
            [
                self.memory_pool_host.get_data_page(
                    host_indices[i * self.page_size], flat=False
                )
                for i in range(page_num)
            ]
            if self.use_zero_copy
            else [
                self.memory_pool_host.get_dummy_flat_data_page()
                for _ in range(page_num)
            ]
        )

        if self.use_zero_copy and not self.is_mla_model:
            # MHA zero-copy splits every page into separate _k/_v transfers.
            keys = self._get_mha_zero_copy_keys(keys)
            values = self._get_mha_zero_copy_values(values)

        return keys, values
721
+
722
    def _batch_get_postprocess(self, host_indices, values, results):
        """Fold per-transfer results back to per-page results and, on the
        staged (non-zero-copy) path, copy fetched pages into the host pool.
        """
        page_num = len(host_indices) // self.page_size

        if self.use_zero_copy:
            if not self.is_mla_model:
                # A page succeeded only if both its _k and _v transfers did.
                results = [
                    (results[2 * i] and results[2 * i + 1]) for i in range(page_num)
                ]
            results = results[:page_num]
            return results

        # dummy page copy to host memory pool
        # NOTE(review): pages after the first failure are skipped — presumably
        # callers only consume a successful prefix; confirm.
        for i in range(page_num):
            if not results[i]:
                break
            # self.memory_pool_host.page_size equals self.page_size
            # (assigned from it in __init__).
            self.memory_pool_host.set_from_flat_data_page(
                host_indices[i * self.memory_pool_host.page_size], values[i]
            )

        return results
742
+
743
+ def batch_get_v1(
744
+ self,
745
+ keys: List[str],
746
+ host_indices: torch.Tensor,
747
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
748
+ ) -> List[bool]:
749
+ keys, values = self._batch_get_preprocess(keys, host_indices)
750
+ results = self.batch_get(keys, values)
751
+ return self._batch_get_postprocess(host_indices, values, results)
752
+
753
+ def _batch_set_preprocess(self, keys, host_indices):
754
+ page_num = len(host_indices) // self.page_size
755
+ flat = not self.use_zero_copy
756
+ values = [
757
+ self.memory_pool_host.get_data_page(
758
+ host_indices[i * self.page_size], flat=flat
759
+ )
760
+ for i in range(page_num)
761
+ ]
762
+
763
+ if self.use_zero_copy and not self.is_mla_model:
764
+ keys = self._get_mha_zero_copy_keys(keys)
765
+ values = self._get_mha_zero_copy_values(values)
766
+
767
+ return keys, values
768
+
769
+ def batch_set_v1(
770
+ self,
771
+ keys: List[str],
772
+ host_indices: torch.Tensor,
773
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
774
+ ) -> List[bool]:
775
+ keys, values = self._batch_set_preprocess(keys, host_indices)
776
+ results = self.batch_set(keys, values)
777
+ return results