sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,778 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import time
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import eic
10
+ import torch
11
+ import yaml
12
+
13
+ from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
14
+ from sglang.srt.mem_cache.hicache_storage import (
15
+ HiCacheStorage,
16
+ HiCacheStorageConfig,
17
+ HiCacheStorageExtraInfo,
18
+ )
19
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache, MLATokenToKVPoolHost
20
+
21
logger = logging.getLogger(__name__)


# Number of staging tensors kept in each FlexibleKVCacheMemoryPool.
TensorPoolSize = 2048

# Environment variable that points at the EIC YAML config file.
REMOTE_EIC_YAML_ENV_VAR = "REMOTE_EIC_YAML"

# Use GPU-direct RDMA for kv set (overwritten from config at init time).
G_EnableKVSetGPUDirect = False

# Use GPU-direct RDMA for kv get (overwritten from config at init time).
G_EnableKVGetGPUDirect = False

# Pin each GPU to a specific NIC (overwritten from config at init time).
G_EnableGPUNicAffinity = False

# Default GPU -> NIC affinity table for the H20 topology.
GPUNicAffinity = {
    "cuda:0": "eth1",
    "cuda:1": "eth1",
    "cuda:2": "eth2",
    "cuda:3": "eth2",
    "cuda:4": "eth3",
    "cuda:5": "eth3",
    "cuda:6": "eth4",
    "cuda:7": "eth4",
}

# Default CPU-side staging device per GPU for the H20 topology
# (used when GPU/NIC affinity is enabled).
CPUNicAffinity = {
    "cuda:0": "cpu",
    "cuda:1": "cpu",
    "cuda:2": "cpu",
    "cuda:3": "cpu",
    "cuda:4": "cpu",
    "cuda:5": "cpu",
    "cuda:6": "cpu",
    "cuda:7": "cpu",
}
60
+
61
+
62
def get_eic_config_file_path():
    """Resolve the path of the EIC YAML config file.

    Prefers the path given by the ``REMOTE_EIC_YAML`` environment variable;
    falls back to the bundled default location otherwise.

    Returns:
        The config file path as a string (existence is not checked here).
    """
    # Read the environment variable once instead of twice (the original
    # called os.environ.get twice for the same key).
    config_file = os.environ.get(REMOTE_EIC_YAML_ENV_VAR)
    if config_file is not None:
        logger.info(f"eic init with env var {REMOTE_EIC_YAML_ENV_VAR}")
    else:
        config_file = "/sgl-workspace/config/remote-eic.yaml"
        logger.info(f"eic init with default config, config_file {config_file}")
    return config_file
70
+
71
+
72
class FlexibleKVCacheMemoryPool:
    """Pre-allocated pool of staging tensors registered with the EIC client.

    One contiguous tensor of shape ``(pool_size, *kvcache_shape)`` is allocated
    up front and registered with EIC once, so pages can be handed out and
    returned without per-page registration calls.
    """

    def __init__(self, conn, kvcache_shape, kvcache_dtype, device):
        self.connection = conn

        # With GPU/NIC affinity enabled, CPU staging buffers are placed on the
        # device named in CPUNicAffinity and the pool is made 5x larger.
        if device.startswith("cpu") and G_EnableGPUNicAffinity:
            gpu_id = torch.cuda.current_device()
            self.device = CPUNicAffinity["cuda:" + str(gpu_id)]
            # current memory pool size is 5 times of CPU TensorPoolSize
            pool_size = TensorPoolSize * 5
        else:
            self.device = device
            pool_size = TensorPoolSize

        self.kvcache_shape = kvcache_shape
        self.kvcache_dtype = kvcache_dtype

        # Element count of a single KV-cache page, used to validate requests.
        page_numel = 1
        for dim in self.kvcache_shape:
            page_numel *= dim
        self.kv_cache_numel = page_numel

        if self.device.startswith("cpu"):
            # Pinned host memory so the NIC can DMA directly from/to it.
            self.kvcache_mempool = torch.zeros(
                (pool_size,) + kvcache_shape,
                dtype=kvcache_dtype,
                device=self.device,
                pin_memory=True,
            )
        else:
            self.kvcache_mempool = torch.zeros(
                (pool_size,) + kvcache_shape, dtype=kvcache_dtype, device=self.device
            )

        # Free slot indices, plus a reverse map from data_ptr to slot index.
        self.free_data_addr = set(range(pool_size))
        self.data_ptr_to_index = {
            self.kvcache_mempool[slot].data_ptr(): slot for slot in range(pool_size)
        }

        # Register the whole pool with EIC in one call.
        meminfo = eic.MemoryInfo()
        meminfo.type = eic.MemoryType.MEMORY_CUDA
        meminfo.cuda_id = 0
        vals = eic.IOBuffers()
        vals.append(
            self.kvcache_mempool.data_ptr(),
            self.kvcache_mempool.numel() * self.kvcache_mempool.element_size(),
            True,
        )
        self.connection.register_memory(vals, meminfo)
        logger.info(
            f"allocate memory pool, size {self.kvcache_mempool.numel() * self.kvcache_mempool.element_size()}, device {self.device}"
        )

    def try_allocate_kv_cache(self, shape, dtype, count=1):
        """Hand out `count` pool pages matching (shape, dtype), or None on failure."""
        if len(self.free_data_addr) < count:
            return None

        requested_numel = 1
        for dim in shape:
            requested_numel *= dim
        if requested_numel != self.kv_cache_numel or dtype != self.kvcache_dtype:
            logger.error(
                f"allocate from mempool failed, self.kvcache_shape {self.kvcache_shape}, dtype {self.kvcache_dtype}, require shape {shape}, dtype {dtype}"
            )
            return None

        return [
            self.kvcache_mempool[self.free_data_addr.pop()] for _ in range(count)
        ]

    def free_to_mempool(self, data_ptr):
        """Return a page (identified by its data_ptr) to the free set."""
        slot = self.data_ptr_to_index.get(data_ptr)
        if slot is None:
            logger.error(
                f"free_to_mempool failed, data_ptr {data_ptr} not in allocated_data_addr"
            )
            return
        self.free_data_addr.add(slot)

    def check_data_ptr_allocated(self, data_ptr):
        """Return True if `data_ptr` belongs to a page of this pool."""
        return data_ptr in self.data_ptr_to_index

    def left_count(self):
        """Return the number of free pages remaining in the pool."""
        return len(self.free_data_addr)
157
+
158
+
159
+ class EICStorage(HiCacheStorage):
160
+ def __init__(
161
+ self, hicache_config: HiCacheStorageConfig, memory_pool_host: HostKVCache
162
+ ):
163
+ global G_EnableKVSetGPUDirect, G_EnableKVGetGPUDirect
164
+ global GPUNicAffinity, CPUNicAffinity, G_EnableGPUNicAffinity
165
+
166
+ config_file = get_eic_config_file_path()
167
+ if os.path.exists(config_file) is False:
168
+ logger.error(f"config file {config_file} not exists")
169
+ raise RuntimeError(f"eic config file {config_file} not exists")
170
+
171
+ with open(config_file, "r") as fin:
172
+ config = yaml.safe_load(fin)
173
+
174
+ remote_url = config.get("remote_url", None)
175
+ if remote_url is None:
176
+ AssertionError("remote_url is None")
177
+
178
+ endpoint = remote_url[len("eic://") :]
179
+
180
+ logger.info(f"eic remote_url:" + remote_url + " endpoint: " + endpoint)
181
+
182
+ eic_instance_id = config.get("eic_instance_id", None)
183
+ logger.info(f"eic instance_id: {eic_instance_id}")
184
+
185
+ eic_thread_num = config.get("eic_thread_num", 1)
186
+ logger.info(f"eic thread_num: {eic_thread_num}")
187
+
188
+ eic_log_dir = config.get("eic_log_dir", None)
189
+ logger.info(f"eic log_dir: {eic_log_dir}")
190
+
191
+ eic_log_level = config.get("eic_log_level", 2)
192
+ logger.info(f"eic log_level: {eic_log_level}")
193
+
194
+ eic_trans_type = config.get("eic_trans_type", 3)
195
+ logger.info(f"eic trans_type: {eic_trans_type}")
196
+
197
+ eic_flag_file = config.get("eic_flag_file", None)
198
+ logger.info(f"eic flag_file: {eic_flag_file}")
199
+
200
+ # GDR now is not used
201
+ G_EnableKVSetGPUDirect = (
202
+ config.get("enable_kvset_gpu_direct", False) and torch.cuda.is_available()
203
+ )
204
+ logger.debug(f"eic enable_kvset_gpu_direct: {G_EnableKVSetGPUDirect}")
205
+
206
+ G_EnableKVGetGPUDirect = (
207
+ config.get("enable_kvget_gpu_direct", False) and torch.cuda.is_available()
208
+ )
209
+ logger.debug(f"eic enable_kvget_gpu_direct: {G_EnableKVGetGPUDirect}")
210
+
211
+ self.model_name = hicache_config.model_name
212
+
213
+ # rdma
214
+ enable_kv_set_direct = config.get("enable_kvset_direct", True)
215
+ logger.info(f"eic enable_kv_set_direct: {enable_kv_set_direct}")
216
+ self.enable_kv_set_direct = enable_kv_set_direct
217
+
218
+ enable_kv_get_direct = config.get("enable_kvget_direct", True)
219
+ logger.info(f"eic enable_kv_get_direct: {enable_kv_get_direct}")
220
+ self.enable_kv_get_direct = enable_kv_get_direct
221
+
222
+ # gpu nic affinity
223
+ G_EnableGPUNicAffinity = config.get("enable_gpu_nic_affinity", False)
224
+ logger.info(f"eic enable_gpu_nic_affinity: {G_EnableGPUNicAffinity}")
225
+ self.enable_gpu_nic_affinity = G_EnableGPUNicAffinity
226
+
227
+ if G_EnableGPUNicAffinity:
228
+ if "gpu_nic_affinity_config" in config:
229
+ GPUNicAffinity = json.loads(config["gpu_nic_affinity_config"])
230
+ if "cpu_nic_affinity_config" in config:
231
+ CPUNicAffinity = json.loads(config["cpu_nic_affinity_config"])
232
+ logger.info(f"eic gpu nic affinity {GPUNicAffinity}")
233
+ logger.info(f"eic cpu nic affinity {CPUNicAffinity}")
234
+
235
+ eic_namespace = config.get("eic_namespace", "")
236
+ logger.info(f"eic namespace: {eic_namespace}")
237
+ self.eic_namespace = eic_namespace
238
+
239
+ if not os.path.exists(eic_log_dir) and not os.path.isdir(eic_log_dir):
240
+ os.makedirs(eic_log_dir, exist_ok=True)
241
+
242
+ self.connection = eic.Client()
243
+ init_option = eic.InitOption()
244
+ init_option.log_dir = eic_log_dir
245
+ init_option.log_level = eic.LogLevel(eic_log_level)
246
+ init_option.transport_type = eic.TransportType(eic_trans_type)
247
+ init_option.flag_file = eic_flag_file
248
+
249
+ if G_EnableGPUNicAffinity:
250
+ gpu_id = torch.cuda.current_device()
251
+ init_option.multi_net_local_interface_names = GPUNicAffinity[
252
+ "cuda:" + str(gpu_id)
253
+ ]
254
+ logger.info(
255
+ f"gpu {gpu_id} set gpu nic affinity to {init_option.multi_net_local_interface_names}"
256
+ )
257
+
258
+ ret = self.connection.init(eic_instance_id, endpoint, init_option)
259
+ if ret != 0:
260
+ logger.error(f"fail to init eic client, ret: {ret}")
261
+ raise RuntimeError("EIC Client Init Failed.")
262
+ self.warmup()
263
+
264
+ self.memory_pool_host = memory_pool_host
265
+ self.host_kvcache_layout = self.memory_pool_host.layout
266
+ self.trans_type = eic.TransportType(eic_trans_type)
267
+ self.kv_cache_dtype = self.memory_pool_host.dtype
268
+ self.is_mla_model = hicache_config.is_mla_model
269
+ self.rank = hicache_config.tp_rank
270
+ self.world_size = hicache_config.tp_size
271
+ self.page_size = self.memory_pool_host.page_size
272
+ self.use_zero_copy = self.memory_pool_host.layout == "page_first"
273
+ if not self.use_zero_copy:
274
+ self.kv_cache_shape = self.memory_pool_host.get_data_page(
275
+ 0, flat=True
276
+ ).shape
277
+ if self.enable_kv_set_direct:
278
+ self.kv_cache_write_mem_pool = FlexibleKVCacheMemoryPool(
279
+ self.connection, self.kv_cache_shape, self.kv_cache_dtype, "cpu"
280
+ )
281
+ if self.enable_kv_get_direct:
282
+ self.kv_cache_get_mem_pool = FlexibleKVCacheMemoryPool(
283
+ self.connection, self.kv_cache_shape, self.kv_cache_dtype, "cpu"
284
+ )
285
+ self._init_eic_prefix()
286
+
287
+ def warmup(self):
288
+ logger.info("begin warm up eic client")
289
+ start_time = time.perf_counter()
290
+ num_warmup = 1024
291
+ preheat_keys = ["warmup_key_" + str(i) for i in range(num_warmup)]
292
+ batch_size = 32
293
+ for i in range(0, num_warmup, batch_size):
294
+ keys_vec = eic.StringVector()
295
+ for key in preheat_keys[i : i + batch_size]:
296
+ keys_vec.append(key)
297
+ exist_option = eic.ExistOption()
298
+ _, _ = self.connection.mexist(keys_vec, exist_option)
299
+ logger.info(
300
+ f"finish eic client warm up, warm up cost {time.perf_counter() - start_time:.2f} seconds"
301
+ )
302
+
303
+ def register_mem_pool_host(self, memory_pool_host: HostKVCache) -> None:
304
+ # no need judge meminfo type, cuda_id, etc.
305
+ meminfo = eic.MemoryInfo()
306
+ meminfo.type = eic.MemoryType.MEMORY_CUDA
307
+ meminfo.cuda_id = 0
308
+ vals = eic.IOBuffers()
309
+ buffer = memory_pool_host.kv_buffer
310
+ vals.append(
311
+ buffer.data_ptr(),
312
+ buffer.numel() * buffer.element_size(),
313
+ True,
314
+ )
315
+ self.connection.register_memory(vals, meminfo)
316
+
317
+ def _init_eic_prefix(self):
318
+ if self.is_mla_model:
319
+ self.eic_prefix = (
320
+ f"{self.model_name}_mla_att_{self.host_kvcache_layout}@sglang"
321
+ )
322
+ else:
323
+ self.eic_prefix = f"{self.model_name}_mha_attn_{self.host_kvcache_layout}_{self.rank}_{self.world_size}_@sglang"
324
+
325
+ def _get_eic_key(self, keys: List[str]) -> str:
326
+ return [f"{self.eic_prefix}_{key}" for key in keys]
327
+
328
+ def set(
329
+ self,
330
+ key: str,
331
+ value: Optional[Any] = None,
332
+ target_location: Optional[Any] = None,
333
+ target_size: Optional[Any] = None,
334
+ ) -> bool:
335
+ # now is not used
336
+ if self.use_zero_copy:
337
+ return self.zero_copy_batch_set([key], [target_location])
338
+ else:
339
+ return self.generic_batch_set([key], [value])
340
+
341
+ # target_locations and target_sizes are not used for now
342
+ def batch_set(
343
+ self,
344
+ keys: List[str],
345
+ values: Optional[Any] = None,
346
+ target_locations: Optional[Any] = None,
347
+ target_sizes: Optional[Any] = None,
348
+ ) -> bool:
349
+ if len(keys) == 0:
350
+ return True
351
+ if self.use_zero_copy:
352
+ return self.zero_copy_batch_set(keys, values)
353
+ else:
354
+ return self.generic_batch_set(keys, values)
355
+
356
+ def get(
357
+ self,
358
+ key,
359
+ target_location: Optional[Any] = None,
360
+ target_size: Optional[Any] = None,
361
+ ) -> torch.Tensor | None:
362
+ # now is not used
363
+ if self.use_zero_copy:
364
+ return self.zero_copy_batch_get([key], [target_location])
365
+ else:
366
+ return self.generic_batch_get([key], [target_location])
367
+
368
+ # use for v1 interface, and shound not be called directly
369
+ def batch_get(
370
+ self,
371
+ keys: List[str],
372
+ target_locations: Optional[Any] = None,
373
+ target_sizes: Optional[Any] = None,
374
+ ) -> List[torch.Tensor | None]:
375
+ assert len(keys) == len(target_locations)
376
+ if len(keys) == 0:
377
+ return None
378
+ if self.use_zero_copy:
379
+ return self.zero_copy_batch_get(keys, target_locations)
380
+ else:
381
+ return self.generic_batch_get(keys, target_locations)
382
+
383
+ def _batch_exists_impl(self, keys) -> List[bool]:
384
+ if len(keys) == 0:
385
+ return 0
386
+ eic_keys = self._get_eic_key(keys)
387
+ logger.debug(f"eic exists {len(keys)}")
388
+ result = []
389
+ exist_bs = 1024
390
+ for i in range(0, len(eic_keys), exist_bs):
391
+ batch_keys = eic_keys[i : i + exist_bs]
392
+ keys_vec = eic.StringVector()
393
+ for key in batch_keys:
394
+ keys_vec.append(key)
395
+ exist_option = eic.ExistOption()
396
+ exist_option.ns = self.eic_namespace
397
+ status_code, exist_outcome = self.connection.mexist(keys_vec, exist_option)
398
+ if status_code != eic.StatusCode.SUCCESS:
399
+ logger.error(
400
+ f"eic exists {len(keys)} failed, status_code {status_code}"
401
+ )
402
+ result.extend([False] * len(batch_keys))
403
+ for err_code in exist_outcome.status_codes:
404
+ result.append(err_code == eic.StatusCode.SUCCESS)
405
+ return result
406
+
407
+ def exists(self, key) -> bool:
408
+ exist_num = self.batch_exists([key])
409
+ return exist_num == 1
410
+
411
+ def batch_exists(self, keys) -> int:
412
+ if len(keys) == 0:
413
+ return 0
414
+ if self.use_zero_copy and not self.is_mla_model:
415
+ keys = self._get_mha_zero_copy_keys(keys)
416
+ exist_mask = self._batch_exists_impl(keys)
417
+ prefix_success = 0
418
+ for exist in exist_mask:
419
+ if exist:
420
+ prefix_success += 1
421
+ else:
422
+ break
423
+ if not self.is_mla_model and self.use_zero_copy:
424
+ prefix_success = prefix_success // 2
425
+ return prefix_success
426
+
427
+ def delete(self, key) -> None:
428
+ eic_keys = self._get_eic_key([key])
429
+ keys_vec = eic.StringVector()
430
+ for eic_key in eic_keys:
431
+ keys_vec.append(eic_key)
432
+ del_option = eic.DelOption()
433
+ self.connection.mdel(keys_vec, del_option)
434
+
435
+ def clear(self) -> None:
436
+ return
437
+
438
+ # Not used for now
439
+ def _filter_kv_cache(self, total_len) -> Tuple[int, int]:
440
+ mean_len = total_len // self.world_size
441
+ remainder = total_len % self.world_size
442
+ tp_keys_len = mean_len + (1 if self.rank < remainder else 0)
443
+ start = self.rank * mean_len + min(self.rank, remainder)
444
+ end = start + tp_keys_len
445
+ logger.debug(f"start: {start}, end: {end}, tp_keys_len: {tp_keys_len}")
446
+ return start, end
447
+
448
+ def zero_copy_batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
449
+ logger.debug(f"eic zero copy set {len(keys)} keys")
450
+ if len(keys) == 0:
451
+ return True
452
+ eic_keys = self._get_eic_key(keys)
453
+ keys_vec = eic.StringVector()
454
+ vals_vec = eic.IOBuffers()
455
+ # set data key & value
456
+ for i, key in enumerate(eic_keys):
457
+ # set data key & value
458
+ keys_vec.append(key)
459
+ vals_vec.append(
460
+ values[i].data_ptr(),
461
+ values[i].element_size() * values[i].numel(),
462
+ True,
463
+ )
464
+ # set options
465
+ set_option = eic.SetOption()
466
+ set_option.ns = self.eic_namespace
467
+ set_option.ttl_second = -1
468
+ status_code, set_outcome = self.connection.mset(keys_vec, vals_vec, set_option)
469
+ if status_code != eic.StatusCode.SUCCESS:
470
+ logger.error(f"eic mset {len(keys)} failed, status_code {status_code}")
471
+ return [False] * len(keys)
472
+ else:
473
+ logger.debug(f"eic zero copy mset {len(keys)} success")
474
+ return [True] * len(keys)
475
+
476
+ def zero_copy_batch_get(
477
+ self, keys: List[str], values: List[torch.Tensor]
478
+ ) -> List[bool]:
479
+ logger.debug(f"eic zero copy get {len(keys)} keys")
480
+ # Get Data: generate data keys and vals
481
+ get_data_start_time = time.perf_counter()
482
+ eic_keys = self._get_eic_key(keys)
483
+ data_keys = eic.StringVector()
484
+ data_vals = eic.IOBuffers()
485
+ success_mask = [True] * len(keys)
486
+ count = len(keys)
487
+ for i, key in enumerate(eic_keys):
488
+ data_keys.append(key)
489
+ data_vals.append(
490
+ values[i].data_ptr(),
491
+ values[i].element_size() * values[i].numel(),
492
+ True,
493
+ )
494
+
495
+ # Get data: recv data buffer tensor
496
+ get_option = eic.GetOption()
497
+ get_option.ns = self.eic_namespace
498
+ status_code, data_vals, get_outcome = self.connection.mget(
499
+ data_keys, get_option, data_vals
500
+ )
501
+
502
+ if status_code != eic.StatusCode.SUCCESS:
503
+ if status_code == eic.StatusCode.PARTIAL_FAILED:
504
+ for i, err_code in enumerate(get_outcome.status_codes):
505
+ success = err_code == eic.StatusCode.SUCCESS
506
+ if success:
507
+ logger.debug(f"eic get data {eic_keys[i]} success")
508
+ else:
509
+ logger.error(
510
+ f"eic get data {eic_keys[i]} failed, err_code {err_code}"
511
+ )
512
+ success_mask[i] = False
513
+ else:
514
+ logger.error(
515
+ f"eic mget {len(eic_keys)} keys failed, status_code {status_code}"
516
+ )
517
+ success_mask = [False] * len(keys)
518
+ return success_mask
519
+
520
+ get_data_end_time = time.perf_counter()
521
+ get_data_execution_time = (get_data_end_time - get_data_start_time) * 1e6
522
+ logger.debug(f"eic get {count} keys data cost %.2f us", get_data_execution_time)
523
+ return success_mask
524
+
525
    def generic_batch_set(
        self,
        keys: List[str],
        values: List[torch.Tensor],
    ) -> List[bool]:
        """Write tensors to EIC, optionally staging through a registered pool.

        When ``enable_kv_set_direct`` is on, source tensors are copied into
        pre-registered pool buffers first (falling back to the raw tensor
        pointers if the pool is exhausted); otherwise raw pointers are sent
        unregistered.  Returns a per-key success list (all-or-nothing).

        NOTE(review): empty input returns bool ``True`` despite the
        ``List[bool]`` annotation; only ``status_codes[0]`` is inspected for
        the final verdict — presumably mset outcomes are uniform; confirm.
        """
        assert len(keys) == len(values)
        logger.debug(f"eic generic set {len(keys)} keys")
        if len(keys) == 0:
            return True
        eic_keys = self._get_eic_key(keys)
        keys_vec = eic.StringVector()
        vals_vec = eic.IOBuffers()
        count = len(keys)
        registered = False
        items = []
        if self.enable_kv_set_direct:
            values_data_ptrs = []
            # try to stage through pre-registered pool buffers
            items = self.kv_cache_write_mem_pool.try_allocate_kv_cache(
                self.kv_cache_shape, self.kv_cache_dtype, count
            )
            if items is None:
                # pool exhausted: fall back to raw, unregistered pointers
                logger.warning("can not allocate tensor from pool")
                for i, value in enumerate(values):
                    values_data_ptrs.append(
                        (value.data_ptr(), value.element_size() * value.numel(), False)
                    )
            else:
                objs = items
                registered = True
                for i, key in enumerate(eic_keys):
                    temp = objs[i].reshape(values[i].shape).contiguous()
                    temp.copy_(values[i])
                    # contiguous() may have reallocated; the copy then lives
                    # outside the registered pool buffer
                    if temp.data_ptr() != objs[i].data_ptr():
                        registered = False
                        temp = temp.cpu()
                    values_data_ptrs.append(
                        (
                            temp.data_ptr(),
                            temp.element_size() * temp.numel(),
                            registered,
                        )
                    )

            for i, key in enumerate(eic_keys):
                keys_vec.append(key)
                # note: rebinds `registered` per entry from the staged tuples
                data_ptr, data_size, registered = values_data_ptrs[i]
                vals_vec.append(data_ptr, data_size, registered)
        else:
            # use tensor direct
            for i, key in enumerate(eic_keys):
                keys_vec.append(key)
                vals_vec.append(
                    values[i].data_ptr(),
                    values[i].element_size() * values[i].numel(),
                    False,
                )

        # set options
        set_option = eic.SetOption()
        set_option.ns = self.eic_namespace
        set_option.ttl_second = -1
        status_code, set_outcome = self.connection.mset(keys_vec, vals_vec, set_option)
        if status_code != eic.StatusCode.SUCCESS:
            logger.error(f"eic mset {len(eic_keys)} failed, status_code {status_code}")
        else:
            logger.debug(f"eic mset {len(eic_keys)} success")

        # return staged buffers to the pool regardless of mset outcome
        if self.enable_kv_set_direct and items is not None:
            for item in items:
                self.kv_cache_write_mem_pool.free_to_mempool(item.data_ptr())

        # NOTE(review): will raise IndexError if status_codes is empty on a
        # total failure — verify eic always populates per-key codes
        err_code = set_outcome.status_codes[0]
        if err_code != eic.StatusCode.SUCCESS:
            logger.error(f"set data key {len(eic_keys)} failed, err_code {err_code}")
            return [False] * len(keys)

        logger.debug(f"set data key {len(eic_keys)} success")
        return [True] * len(keys)
603
+
604
    def generic_batch_get(
        self, keys: List[str], buffers: List[torch.Tensor]
    ) -> List[bool]:
        """Fetch entries from EIC into ``buffers``, optionally via the pool.

        When ``enable_kv_get_direct`` is on and pool buffers are available,
        data lands in pre-registered pool tensors and is then copied into the
        caller's ``buffers``; otherwise the caller's buffers receive the data
        directly (unregistered).  Returns a per-key success mask: on
        ``PARTIAL_FAILED`` only failed keys are False, on any other failure
        the whole batch is False.
        """
        # all success or all fail
        logger.debug(f"eic generic get {len(keys)} keys")
        eic_keys = self._get_eic_key(keys)
        get_data_start_time = time.perf_counter()
        data_keys = eic.StringVector()
        data_vals = eic.IOBuffers()
        count = len(eic_keys)
        registered = False
        items = []
        success_mask = [True] * len(keys)
        if self.enable_kv_get_direct:
            items = self.kv_cache_get_mem_pool.try_allocate_kv_cache(
                self.kv_cache_shape, self.kv_cache_dtype, count
            )
            if items is None:
                # pool exhausted: receive directly into caller buffers
                logger.warning("can not allocate tensor from pool")
                for i, key in enumerate(eic_keys):
                    data_keys.append(key)
                    data_vals.append(
                        buffers[i].data_ptr(),
                        buffers[i].element_size() * buffers[i].numel(),
                        False,
                    )
            else:
                # receive into registered pool buffers, copy out afterwards
                registered = True
                for i, key in enumerate(eic_keys):
                    data_keys.append(key)
                    data_vals.append(
                        items[i].data_ptr(),
                        items[i].element_size() * items[i].numel(),
                        registered,
                    )

        else:
            for i, key in enumerate(eic_keys):
                data_keys.append(key)
                data_vals.append(
                    buffers[i].data_ptr(),
                    buffers[i].element_size() * buffers[i].numel(),
                    False,
                )

        # Get data: recv data buffer tensor
        get_option = eic.GetOption()
        get_option.ns = self.eic_namespace
        status_code, data_vals, get_outcome = self.connection.mget(
            data_keys, get_option, data_vals
        )

        if status_code != eic.StatusCode.SUCCESS:
            if status_code == eic.StatusCode.PARTIAL_FAILED:
                for i, err_code in enumerate(get_outcome.status_codes):
                    success = err_code == eic.StatusCode.SUCCESS
                    if success:
                        logger.debug(f"eic get data {eic_keys[i]} success")
                    else:
                        logger.error(
                            f"eic get data {eic_keys[i]} failed, err_code {err_code}"
                        )
                        success_mask[i] = False
            else:
                logger.error(
                    f"eic mget {len(eic_keys)} keys failed, status_code {status_code}"
                )
                success_mask = [False] * len(keys)

        # copy successfully fetched pool buffers out to the caller; every
        # pool buffer is freed regardless of per-key success
        if registered:
            for i, item in enumerate(items):
                if success_mask[i]:
                    buffers[i].copy_(item)
                self.kv_cache_get_mem_pool.free_to_mempool(item.data_ptr())

        get_data_end_time = time.perf_counter()
        get_data_execution_time = (get_data_end_time - get_data_start_time) * 1e6
        logger.debug(f"eic get {count} keys data cost %.2f us", get_data_execution_time)
        return success_mask
683
+
684
+ def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
685
+ new_keys = []
686
+ for k in keys:
687
+ new_keys.append(f"{k}_k")
688
+ new_keys.append(f"{k}_v")
689
+ return new_keys
690
+
691
+ def _get_mha_zero_copy_values(
692
+ self, values: List[torch.Tensor]
693
+ ) -> List[torch.Tensor]:
694
+ new_values = []
695
+ for value in values:
696
+ new_values.append(value[0])
697
+ new_values.append(value[1])
698
+ return new_values
699
+
700
+ def _batch_get_preprocess(self, keys, host_indices):
701
+ page_num = len(host_indices) // self.page_size
702
+ # use memory pool directly or dummy page
703
+ values = (
704
+ [
705
+ self.memory_pool_host.get_data_page(
706
+ host_indices[i * self.page_size], flat=False
707
+ )
708
+ for i in range(page_num)
709
+ ]
710
+ if self.use_zero_copy
711
+ else [
712
+ self.memory_pool_host.get_dummy_flat_data_page()
713
+ for _ in range(page_num)
714
+ ]
715
+ )
716
+
717
+ if self.use_zero_copy and not self.is_mla_model:
718
+ keys = self._get_mha_zero_copy_keys(keys)
719
+ values = self._get_mha_zero_copy_values(values)
720
+
721
+ return keys, values
722
+
723
+ def _batch_get_postprocess(self, host_indices, values, results):
724
+ page_num = len(host_indices) // self.page_size
725
+
726
+ if self.use_zero_copy:
727
+ if not self.is_mla_model:
728
+ results = [
729
+ (results[2 * i] and results[2 * i + 1]) for i in range(page_num)
730
+ ]
731
+ results = results[:page_num]
732
+ return results
733
+
734
+ # dummy page copy to host memory pool
735
+ for i in range(page_num):
736
+ if not results[i]:
737
+ break
738
+ self.memory_pool_host.set_from_flat_data_page(
739
+ host_indices[i * self.memory_pool_host.page_size], values[i]
740
+ )
741
+
742
+ return results
743
+
744
+ def batch_get_v1(
745
+ self,
746
+ keys: List[str],
747
+ host_indices: torch.Tensor,
748
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
749
+ ) -> List[bool]:
750
+ keys, values = self._batch_get_preprocess(keys, host_indices)
751
+ results = self.batch_get(keys, values)
752
+ return self._batch_get_postprocess(host_indices, values, results)
753
+
754
+ def _batch_set_preprocess(self, keys, host_indices):
755
+ page_num = len(host_indices) // self.page_size
756
+ flat = not self.use_zero_copy
757
+ values = [
758
+ self.memory_pool_host.get_data_page(
759
+ host_indices[i * self.page_size], flat=flat
760
+ )
761
+ for i in range(page_num)
762
+ ]
763
+
764
+ if self.use_zero_copy and not self.is_mla_model:
765
+ keys = self._get_mha_zero_copy_keys(keys)
766
+ values = self._get_mha_zero_copy_values(values)
767
+
768
+ return keys, values
769
+
770
+ def batch_set_v1(
771
+ self,
772
+ keys: List[str],
773
+ host_indices: torch.Tensor,
774
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
775
+ ) -> List[bool]:
776
+ keys, values = self._batch_set_preprocess(keys, host_indices)
777
+ results = self.batch_set(keys, values)
778
+ return results