sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
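The per-file "+added -removed" counts above can be approximated locally from the two wheels. The sketch below is illustrative only: it is not part of sglang, the wheel paths are assumptions for wherever the files were downloaded, and the registry's diff tooling may use a different diff algorithm, so exact counts can vary slightly.

```python
# Hypothetical helper: count added/removed lines per .py file between two wheels.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.5.3rc0-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "sglang-0.5.4-py3-none-any.whl"     # assumed local path


def read_py_files(path):
    """Map member name -> list of text lines for every .py member of a wheel (a zip archive)."""
    out = {}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            if name.endswith(".py"):
                out[name] = zf.read(name).decode("utf-8", errors="replace").splitlines()
    return out


old_files, new_files = read_py_files(OLD_WHEEL), read_py_files(NEW_WHEEL)
for name in sorted(set(old_files) | set(new_files)):
    a, b = old_files.get(name, []), new_files.get(name, [])
    added = removed = 0
    for line in difflib.unified_diff(a, b, lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"{name} +{added} -{removed}")
```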
@@ -1,7 +1,6 @@
 import abc
 import logging
 import threading
-from enum import IntEnum
 from functools import wraps
 from typing import Optional
 
@@ -31,27 +30,13 @@ if not (_is_npu or _is_xpu):
 logger = logging.getLogger(__name__)
 
 
-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
 
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
+    return wrapper
 
 
 class HostKVCache(abc.ABC):
@@ -110,7 +95,6 @@ class HostKVCache(abc.ABC):
 
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()
 
     @abc.abstractmethod
@@ -140,7 +124,7 @@ class HostKVCache(abc.ABC):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         """
         Get a flat data page from the host memory pool.
         """
@@ -161,7 +145,7 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
-    @synchronized()
+    @synchronized
     def clear(self):
         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(
@@ -172,7 +156,7 @@ class HostKVCache(abc.ABC):
     def available_size(self):
         return len(self.free_slots)
 
-    @synchronized()
+    @synchronized
    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
@@ -183,92 +167,13 @@ class HostKVCache(abc.ABC):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
         return select_index
 
-    @synchronized()
+    @synchronized
     def free(self, indices: torch.Tensor) -> int:
         self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
         return len(indices)
 
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_prefetch(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
 
 class MHATokenToKVPoolHost(HostKVCache):
     device_pool: MHATokenToKVPool
@@ -461,13 +366,19 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page
 
     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -494,12 +405,22 @@ class MHATokenToKVPoolHost(HostKVCache):
                     2, self.page_size, self.layer_num, self.head_num, self.head_dim
                 )
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
+                data_page.reshape(
+                    2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
+                )
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """ "
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
         v_offset = (
@@ -509,48 +430,52 @@
             * self.head_dim
             * self.dtype.itemsize
         )
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                    )
+                    v_ptr = k_ptr + v_offset
+                    ptr_list.append(k_ptr)
+                    ptr_list.append(v_ptr)
+            element_size = (
+                self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
                 * self.head_num
                 * self.head_dim
-                * self.dtype.itemsize
             )
-            v_ptr = k_ptr + v_offset
-            ptr_list.append(k_ptr)
-            ptr_list.append(v_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{local_rank}_k")
-            key_list.append(f"{key_}_{local_rank}_v")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * self.head_num
-            * self.head_dim
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        key_list = []
-        buf_list = []
-
-        for i in range(len(keys)):
-            key = keys[i]
-            key_list.append(f"{key}-k")
-            key_list.append(f"{key}-v")
-            if indices is not None:
-                index = indices[i * self.page_size]
-                buf_list.append(self.k_buffer[index : index + self.page_size])
-                buf_list.append(self.v_buffer[index : index + self.page_size])
-
-        return key_list, buf_list, 2
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list
 
 
 class MLATokenToKVPoolHost(HostKVCache):
@@ -726,13 +651,19 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page
 
     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -762,43 +693,63 @@
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[real_index : real_index + 1, :, :, :, :] = data_page.reshape(
+                1,
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """ "
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                    )
+                    ptr_list.append(k_ptr)
+            element_size = (
+                self.dtype.itemsize
+                * self.page_size
                 * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+            element_size = (
+                self.layer_num
                 * self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
             )
-            ptr_list.append(k_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * (self.kv_lora_rank + self.qk_rope_head_dim)
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        buf_list = []
-
-        if indices is not None:
-            for i in range(len(keys)):
-                index = indices[i * self.page_size]
-                buf_list.append(self.kv_buffer[index : index + self.page_size])
-
-        return keys, buf_list, 1
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list
@@ -1,6 +1,5 @@
 import logging
 from collections import OrderedDict
-from typing import Dict
 
 import torch
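The hunks above make three caller-visible changes to the host KV cache helpers: the `synchronized` decorator drops its `debug_only` parameter along with the `MemoryStateInt` debug state machine, `get_flat_data_page(index)` becomes `get_data_page(index, flat=True)` with a new `page_first_direct` layout, and `get_buffer_meta(keys, indices, local_rank)` is replaced by `get_page_buffer_meta(indices)`, which returns only `(ptr_list, element_size_list)` (`get_buffer_with_hash` is removed). The snippet below is a minimal, self-contained sketch of the decorator pattern only; `ToyHostPool` is an illustrative stand-in, not an sglang class.

```python
# Sketch of the 0.5.4-style `synchronized` decorator (verbatim from the diff)
# applied to a toy class that mimics the lock/free-slot pattern of HostKVCache.
import threading
from functools import wraps


def synchronized(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper


class ToyHostPool:
    """Illustrative stand-in; only the locking pattern matters here."""

    def __init__(self, size: int):
        self.lock = threading.RLock()
        self.free_slots = list(range(size))

    @synchronized  # 0.5.4 style; 0.5.3 used @synchronized() / @synchronized(debug_only=True)
    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        taken, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return taken

    @synchronized
    def free(self, indices):
        self.free_slots.extend(indices)
        return len(indices)


pool = ToyHostPool(8)
print(pool.alloc(4))   # [0, 1, 2, 3]
print(pool.free([0]))  # 1
```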