sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -15,9 +15,12 @@ limitations under the License.
 
 from __future__ import annotations
 
+from dataclasses import dataclass
+
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.layers.attention.nsa import index_buf_accessor
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
 Memory pool.
@@ -44,6 +47,8 @@ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
+    from sglang.srt.managers.schedule_batch import Req
+
 
 logger = logging.getLogger(__name__)
 
@@ -109,92 +114,135 @@ class ReqToTokenPool:
 
 
 class MambaPool:
+    @dataclass(frozen=True, kw_only=True)
+    class State:
+        conv: torch.Tensor
+        temporal: torch.Tensor
+
+        def at_layer_idx(self, layer: int):
+            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+        def mem_usage_bytes(self):
+            return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+    @dataclass(frozen=True, kw_only=True)
+    class SpeculativeState(State):
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: torch.Tensor
+
     def __init__(
         self,
+        *,
         size: int,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        num_mamba_layers: int,
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
+        cache_params: "Mamba2CacheParams",
         device: str,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
-        conv_state = torch.zeros(
-            size=(num_mamba_layers, size + 1) + conv_state_shape,
-            dtype=conv_dtype,
-            device=device,
-        )
-        temporal_state = torch.zeros(
-            size=(num_mamba_layers, size + 1) + temporal_state_shape,
-            dtype=ssm_dtype,
-            device=device,
+        conv_state_shape = cache_params.shape.conv
+        temporal_state_shape = cache_params.shape.temporal
+        conv_dtype = cache_params.dtype.conv
+        ssm_dtype = cache_params.dtype.temporal
+        num_mamba_layers = len(cache_params.layers)
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
         )
-        if speculative_num_draft_tokens is not None:
-            # Cache intermediate SSM states per draft token during target verify
-            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
-            intermediate_ssm_state_cache = torch.zeros(
-                size=(
-                    num_mamba_layers,
-                    size + 1,
-                    speculative_num_draft_tokens,
-                    temporal_state_shape[0],
-                    temporal_state_shape[1],
-                    temporal_state_shape[2],
-                ),
-                dtype=ssm_dtype,
-                device="cuda",
-            )
-            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
-            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
-            intermediate_conv_window_cache = torch.zeros(
-                size=(
-                    num_mamba_layers,
-                    size + 1,
-                    speculative_num_draft_tokens,
-                    conv_state_shape[0],
-                    conv_state_shape[1],
-                ),
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.enable_custom_mem_pool
+            else nullcontext()
+        ):
+            # assume conv_state = (dim, state_len)
+            assert conv_state_shape[0] > conv_state_shape[1]
+            conv_state = torch.zeros(
+                size=(num_mamba_layers, size + 1) + conv_state_shape,
                 dtype=conv_dtype,
-                device="cuda",
-            )
-            self.mamba_cache = (
-                conv_state,
-                temporal_state,
-                intermediate_ssm_state_cache,
-                intermediate_conv_window_cache,
+                device=device,
             )
-            logger.info(
-                f"Mamba Cache is allocated. "
-                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
-                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
-                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
-                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+            temporal_state = torch.zeros(
+                size=(num_mamba_layers, size + 1) + temporal_state_shape,
+                dtype=ssm_dtype,
+                device=device,
             )
-        else:
-            self.mamba_cache = (conv_state, temporal_state)
-            logger.info(
-                f"Mamba Cache is allocated. "
-                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
-                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+            if speculative_num_draft_tokens is not None:
+                # Cache intermediate SSM states per draft token during target verify
+                # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+                intermediate_ssm_state_cache = torch.zeros(
+                    size=(
+                        num_mamba_layers,
+                        size + 1,
+                        speculative_num_draft_tokens,
+                        temporal_state_shape[0],
+                        temporal_state_shape[1],
+                        temporal_state_shape[2],
+                    ),
+                    dtype=ssm_dtype,
+                    device="cuda",
+                )
+                # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+                # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+                intermediate_conv_window_cache = torch.zeros(
+                    size=(
+                        num_mamba_layers,
+                        size + 1,
+                        speculative_num_draft_tokens,
+                        conv_state_shape[0],
+                        conv_state_shape[1],
+                    ),
+                    dtype=conv_dtype,
+                    device="cuda",
+                )
+                self.mamba_cache = self.SpeculativeState(
+                    conv=conv_state,
+                    temporal=temporal_state,
+                    intermediate_ssm=intermediate_ssm_state_cache,
+                    intermediate_conv_window=intermediate_conv_window_cache,
+                )
+                logger.info(
+                    f"Mamba Cache is allocated. "
+                    f"max_mamba_cache_size: {size}, "
+                    f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                    f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                    f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                    f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+                )
+            else:
+                self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
+                logger.info(
+                    f"Mamba Cache is allocated. "
+                    f"max_mamba_cache_size: {size}, "
+                    f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                    f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                )
+        self.size = size
+        self.device = device
+        self.free_slots = torch.arange(
+            self.size, dtype=torch.int64, device=self.device
             )
-        self.size = size
-        self.free_slots = list(range(size))
-        self.mem_usage = self.get_mamba_size() / GB
-
-    def get_mamba_params_all_layers(self):
-        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+        self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
+        self.num_mamba_layers = num_mamba_layers
 
-    def get_mamba_params(self, layer_id: int):
-        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+    def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+        assert isinstance(self.mamba_cache, self.SpeculativeState)
+        return self.mamba_cache
 
-    def get_mamba_size(self):
-        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+    def mamba2_layer_cache(self, layer_id: int):
+        return self.mamba_cache.at_layer_idx(layer_id)
 
     def available_size(self):
         return len(self.free_slots)
 
-    def alloc(self, need_size: int) -> Optional[List[int]]:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         if need_size > len(self.free_slots):
             return None
 
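Note: the frozen dataclasses above replace the old positional tuple for the Mamba cache. A minimal sketch (not sglang code) of the State/at_layer_idx pattern — vars() on a dataclass instance returns its field dict, and indexing a tensor's first dim returns a cheap view:

    from dataclasses import dataclass

    import torch


    @dataclass(frozen=True, kw_only=True)
    class State:
        conv: torch.Tensor      # [num_layers, slots, dim, state_len]
        temporal: torch.Tensor  # [num_layers, slots, HV, K, V]

        def at_layer_idx(self, layer: int):
            # Slice every field at one layer index; no data is copied.
            return type(self)(**{k: v[layer] for k, v in vars(self).items()})


    state = State(conv=torch.zeros(4, 8, 16, 3), temporal=torch.zeros(4, 8, 2, 4, 4))
    layer0 = state.at_layer_idx(0)
    assert layer0.conv.shape == (8, 16, 3)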
@@ -203,15 +251,46 @@ class MambaPool:
 
         return select_index
 
-    def free(self, free_index: Union[int, List[int]]):
-        if isinstance(free_index, (int,)):
-            self.free_slots.append(free_index)
-        else:
-            self.free_slots.extend(free_index)
-        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+        self.free_slots = torch.cat((self.free_slots, free_index))
+        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+            :, free_index
+        ] = 0
 
     def clear(self):
-        self.free_slots = list(range(self.size))
+        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
+
+    def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
+        self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
+        self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
+            :, src_index
+        ]
+        return
+
+    def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
+        dst_index = self.alloc(1)
+        if dst_index == None:
+            return None
+        self.copy_from(src_index, dst_index)
+        return dst_index
+
+    def get_contiguous_buf_infos(self):
+        state_tensors = [
+            getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
+        ]
+        data_ptrs, data_lens, item_lens = [], [], []
+
+        for _, state_tensor in enumerate(state_tensors):
+            data_ptrs += [
+                state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
+            ]
+            data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
+            item_lens += [
+                state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
+            ]
+        return data_ptrs, data_lens, item_lens
 
 
 class HybridReqToTokenPool(ReqToTokenPool):
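Note: free_slots moves from a Python list to a device tensor in this release. A sketch of the alloc/free semantics under that change (the slicing in alloc is an assumption based on the `free_slots[:need]` pattern visible elsewhere in the hunk):

    import torch

    free_slots = torch.arange(8, dtype=torch.int64)

    def alloc(need: int):
        global free_slots
        if need > len(free_slots):
            return None
        out, free_slots = free_slots[:need], free_slots[need:]
        return out

    def free(idx: torch.Tensor):
        global free_slots
        if idx.numel() == 0:
            return
        free_slots = torch.cat((free_slots, idx))  # return slots to the pool

    a = alloc(3)                  # tensor([0, 1, 2])
    free(a)
    assert len(free_slots) == 8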
@@ -219,16 +298,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
     def __init__(
         self,
+        *,
         size: int,
+        mamba_size: int,
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        mamba_layers: List[int],
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
-        speculative_num_draft_tokens: int,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int = None,
     ):
         super().__init__(
             size=size,
@@ -236,31 +313,37 @@
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
+        self._init_mamba_pool(
+            size=mamba_size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
+        )
 
+    def _init_mamba_pool(
+        self,
+        size: int,
+        cache_params: "Mamba2CacheParams",
+        device: str,
+        speculative_num_draft_tokens: int = None,
+    ):
         self.mamba_pool = MambaPool(
-            size,
-            conv_dtype,
-            ssm_dtype,
-            len(mamba_layers),
-            conv_state_shape,
-            temporal_state_shape,
-            device,
-            speculative_num_draft_tokens,
+            size=size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
-        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
 
         self.device = device
         self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
             size, dtype=torch.int32, device=self.device
         )
 
-        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
-        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
-
     # For chunk prefill req, we do not need to allocate mamba cache,
     # We could use allocated mamba cache instead.
     def alloc(
-        self, need_size: int, reqs: Optional[List["Req"]] = None
+        self, need_size: int, reqs: Optional[List[Req]] = None
     ) -> Optional[List[int]]:
         select_index = super().alloc(need_size)
         if select_index == None:
@@ -268,14 +351,14 @@
 
         mamba_index = []
         for req in reqs:
-            rid = req.rid
-            if rid in self.rid_to_mamba_index_mapping:
-                mid = self.rid_to_mamba_index_mapping[rid]
-            elif (mid := self.mamba_pool.alloc(1)) is not None:
-                mid = mid[0]
-                self.rid_to_mamba_index_mapping[rid] = mid
-                self.mamba_index_to_rid_mapping[mid] = rid
-            mamba_index.append(mid)
+            mid = None
+            if req.mamba_pool_idx is not None:  # for radix cache
+                mid = req.mamba_pool_idx
+            else:
+                mid = self.mamba_pool.alloc(1)[0]
+                req.mamba_pool_idx = mid
+            if mid is not None:
+                mamba_index.append(mid)
         assert len(select_index) == len(
             mamba_index
         ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
@@ -287,26 +370,21 @@
     def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
         return self.req_index_to_mamba_index_mapping[req_indices]
 
-    def get_mamba_params(self, layer_id: int):
+    def mamba2_layer_cache(self, layer_id: int):
         assert layer_id in self.mamba_map
-        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
 
-    def get_mamba_params_all_layers(self):
-        return self.mamba_pool.get_mamba_params_all_layers()
+    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+        return self.mamba_pool.get_speculative_mamba2_params_all_layers()
 
     # For chunk prefill, we can not free mamba cache, we need use it in the future
     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        if isinstance(free_index, (int,)):
+            free_index = [free_index]
         super().free(free_index)
         if free_mamba_cache:
             mamba_index = self.req_index_to_mamba_index_mapping[free_index]
-            mamba_index_list = mamba_index.tolist()
-            if isinstance(mamba_index_list, int):
-                mamba_index_list = [mamba_index_list]
-            self.mamba_pool.free(mamba_index_list)
-            for mid in mamba_index_list:
-                rid = self.mamba_index_to_rid_mapping[mid]
-                self.mamba_index_to_rid_mapping.pop(mid)
-                self.rid_to_mamba_index_mapping.pop(rid)
+            self.mamba_pool.free(mamba_index)
 
     def clear(self):
         super().clear()
@@ -349,6 +427,19 @@ class KVCache(abc.ABC):
         # default state for optional layer-wise transfer control
         self.layer_transfer_counter = None
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
     def _finalize_allocation_log(self, num_tokens: int):
         """Common logging and mem_usage computation for KV cache allocation.
         Supports both tuple (K, V) size returns and single KV size returns.
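Note: the NVLink custom-allocator setup is hoisted from the individual pools into the KVCache base here. A sketch of the conditional allocation pattern the pools share (routing allocations through a torch.cuda.MemPool when configured, otherwise a no-op nullcontext):

    from contextlib import nullcontext

    import torch

    custom_mem_pool = None  # e.g. torch.cuda.MemPool(allocator.allocator()) for NVLink

    with (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    ):
        # Any CUDA tensor created inside this block comes from the custom pool
        # (when one is set); otherwise the default caching allocator is used.
        buf = torch.zeros(1024, device="cuda" if torch.cuda.is_available() else "cpu")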
@@ -400,6 +491,9 @@ class KVCache(abc.ABC):
     def load_cpu_copy(self, kv_cache_cpu, indices):
         raise NotImplementedError()
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
 
 class MHATokenToKVPool(KVCache):
 
@@ -415,6 +509,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -429,25 +524,61 @@
         self.head_num = head_num
         self.head_dim = head_dim
 
-        # for disagg with nvlink
-        self.enable_custom_mem_pool = get_bool_env_var(
-            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
-        )
-        if self.enable_custom_mem_pool:
-            # TODO(shangming): abstract custom allocator class for more backends
-            from mooncake.allocator import NVLinkAllocator
-
-            allocator = NVLinkAllocator.get_allocator(self.device)
-            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
-        else:
-            self.custom_mem_pool = None
-
         self._create_buffers()
 
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
+
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)
 
+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
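Note: the dummy launch above warms up (compiles) the Triton kernel once, outside the hot path. A worked example of the tile-size heuristic, using the constants from the diff and a hypothetical per-layer row stride of 6144 bytes:

    stride_bytes = 6144  # hypothetical; one row of one layer buffer, in bytes
    if stride_bytes >= 8192:
        bytes_per_tile = 512
    elif stride_bytes >= 4096:
        bytes_per_tile = 256
    else:
        bytes_per_tile = 128

    byte_tiles = (stride_bytes + bytes_per_tile - 1) // bytes_per_tile  # ceil div
    num_warps = 4 if bytes_per_tile <= 256 else 8
    assert (bytes_per_tile, byte_tiles, num_warps) == (256, 24, 4)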
@@ -535,9 +666,6 @@
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
-    def maybe_get_custom_mem_pool(self):
-        return self.custom_mem_pool
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -642,13 +770,28 @@
             self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
            self.data_ptrs,
            self.data_strides,
            tgt_loc,
            src_loc,
-            len(tgt_loc),
-            next_power_of_2(len(tgt_loc)),
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
        )
 
 
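Note: the launch grid is (number of registered buffers, byte tiles), so each Triton program copies one byte tile of one buffer for every (src, tgt) location pair. A shape sketch with hypothetical values, assuming K and V buffers are registered separately per layer in data_ptrs:

    layer_num, byte_tiles = 32, 24   # hypothetical values
    num_buffers = 2 * layer_num      # assumed: one K and one V buffer per layer
    grid = (num_buffers, byte_tiles)
    assert grid == (64, 24)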
@@ -665,12 +808,18 @@
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
+        mamba_pool: MambaPool,
     ):
         self.size = size
         self.dtype = dtype
         self.device = device
         self.full_layer_nums = len(full_attention_layer_ids)
         self.page_size = page_size
+        # TODO support pp?
+        self.start_layer = 0
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.mamba_pool = mamba_pool
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
         if _is_npu:
@@ -699,6 +848,15 @@
     def get_contiguous_buf_infos(self):
         return self.full_kv_pool.get_contiguous_buf_infos()
 
+    def get_state_buf_infos(self):
+        mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
+            self.mamba_pool.get_contiguous_buf_infos()
+        )
+        return mamba_data_ptrs, mamba_data_lens, mamba_item_lens
+
+    def maybe_get_custom_mem_pool(self):
+        return self.full_kv_pool.maybe_get_custom_mem_pool()
+
     def _transfer_full_attention_id(self, layer_id: int):
         if layer_id not in self.full_attention_layer_id_mapping:
             raise ValueError(
@@ -749,28 +907,57 @@ class SWAKVPool(KVCache):
         self,
         size: int,
         size_swa: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
+        device: str,
         token_to_kv_pool_class: KVCache = MHATokenToKVPool,
         **kwargs,
     ):
         self.size = size
         self.size_swa = size_swa
+        self.dtype = dtype
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.device = device
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
+        self.start_layer = 0
+        self.page_size = 1
+
         kwargs["page_size"] = 1
         kwargs["enable_memory_saver"] = False
+        kwargs["head_num"] = head_num
+        kwargs["head_dim"] = head_dim
+        kwargs["device"] = device
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
+            dtype=dtype,
             layer_num=self.swa_layer_nums,
             **kwargs,
         )
         self.full_kv_pool = token_to_kv_pool_class(
             size=size,
+            dtype=dtype,
             layer_num=self.full_layer_nums,
             **kwargs,
         )
@@ -783,6 +970,9 @@
 
         k_size, v_size = self.get_kv_size_bytes()
         self.mem_usage = (k_size + v_size) / GB
+        logger.info(
+            f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}"
+        )
 
     def get_kv_size_bytes(self):
         k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
@@ -793,15 +983,19 @@
         full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
             self.full_kv_pool.get_contiguous_buf_infos()
         )
+
+        kv_data_ptrs = full_kv_data_ptrs
+        kv_data_lens = full_kv_data_lens
+        kv_item_lens = full_kv_item_lens
+
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+    def get_state_buf_infos(self):
         swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
             self.swa_kv_pool.get_contiguous_buf_infos()
         )
 
-        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
-        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
-        kv_item_lens = full_kv_item_lens + swa_kv_item_lens
-
-        return kv_data_ptrs, kv_data_lens, kv_item_lens
+        return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens
 
     def get_key_buffer(self, layer_id: int):
         layer_id_pool, is_swa = self.layers_mapping[layer_id]
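Note: across this release, get_contiguous_buf_infos() reports only the primary KV buffers, while the new get_state_buf_infos() reports auxiliary state (SWA KV here; Mamba state or NSA indexer pages in the other pools). A sketch of how a transfer engine might consume the split — the pool object and register callback are assumed names:

    def register_all_regions(pool, register):
        # Primary (full-attention) KV regions.
        for ptr, length, item_len in zip(*pool.get_contiguous_buf_infos()):
            register(ptr, length, item_len, kind="kv")
        # Auxiliary state regions (SWA KV, Mamba state, NSA indexer pages, ...).
        for ptr, length, item_len in zip(*pool.get_state_buf_infos()):
            register(ptr, length, item_len, kind="state")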
@@ -1019,6 +1213,65 @@ def set_mla_kv_buffer_triton(
     )
 
 
+@triton.jit
+def get_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    loc = tl.load(loc_ptr + pid_loc)
+    loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
+
+    nope_offs = tl.arange(0, nope_dim)
+    nope_src_ptr = loc_src_ptr + nope_offs
+    nope_src = tl.load(nope_src_ptr)
+
+    tl.store(
+        cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
+        nope_src,
+    )
+
+    rope_offs = tl.arange(0, rope_dim)
+    rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
+    rope_src = tl.load(rope_src_ptr)
+    tl.store(
+        cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
+        rope_src,
+    )
+
+
+def get_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    nope_dim = cache_k_nope.shape[-1]  # 512
+    rope_dim = cache_k_rope.shape[-1]  # 64
+    n_loc = loc.numel()
+    grid = (n_loc,)
+
+    get_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
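Note: a plain-PyTorch reference (a sketch, not sglang code) of what the new Triton kernel does — gather rows of the packed MLA KV buffer at `loc` and split them into the nope part (kv_lora_rank, e.g. 512) and the rope part (e.g. 64), casting on store:

    import torch

    def get_mla_kv_buffer_reference(kv_buffer, loc, nope_dim=512, rope_dim=64,
                                    dst_dtype=torch.bfloat16):
        rows = kv_buffer[loc]                          # [n_loc, 1, nope_dim + rope_dim]
        cache_k_nope = rows[..., :nope_dim].to(dst_dtype)
        cache_k_rope = rows[..., nope_dim:nope_dim + rope_dim].to(dst_dtype)
        return cache_k_nope, cache_k_rope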
@@ -1057,19 +1310,6 @@ class MLATokenToKVPool(KVCache):
             else (kv_lora_rank + qk_rope_head_dim)
         )
 
-        # for disagg with nvlink
-        self.enable_custom_mem_pool = get_bool_env_var(
-            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
-        )
-        if self.enable_custom_mem_pool:
-            # TODO(shangming): abstract custom allocator class for more backends
-            from mooncake.allocator import NVLinkAllocator
-
-            allocator = NVLinkAllocator.get_allocator(self.device)
-            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
-        else:
-            self.custom_mem_pool = None
-
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
                 torch.cuda.use_mem_pool(self.custom_mem_pool)
@@ -1091,7 +1331,9 @@
                 dtype=torch.uint64,
                 device=self.device,
             )
-        self._finalize_allocation_log(size)
+        if not use_nsa:
+            # NSA will allocate indexer KV cache later and then log the total size
+            self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -1110,9 +1352,6 @@
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
-    def maybe_get_custom_mem_pool(self):
-        return self.custom_mem_pool
-
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -1183,6 +1422,29 @@
             cache_k_rope,
         )
 
+    def get_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        dst_dtype: Optional[torch.dtype] = None,
+    ):
+        # get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype.
+        layer_id = layer.layer_id
+        kv_buffer = self.get_key_buffer(layer_id)
+        dst_dtype = dst_dtype or self.dtype
+        cache_k_nope = torch.empty(
+            (loc.shape[0], 1, self.kv_lora_rank),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        cache_k_rope = torch.empty(
+            (loc.shape[0], 1, self.qk_rope_head_dim),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
+        return cache_k_nope, cache_k_rope
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -1212,6 +1474,9 @@
 
 
 class NSATokenToKVPool(MLATokenToKVPool):
+    quant_block_size = 128
+    index_k_with_scale_buffer_dtype = torch.uint8
+
     def __init__(
         self,
         size: int,
@@ -1245,27 +1510,33 @@
         # num head == 1 and head dim == 128 for index_k in NSA
         assert index_head_dim == 128
 
-        self.quant_block_size = 128
-
         assert self.page_size == 64
-        self.index_k_with_scale_buffer = [
-            torch.zeros(
-                # Layout:
-                # ref: test_attention.py :: kv_cache_cast_to_fp8
-                # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
-                # data: for page i,
-                # * buf[i, :page_size * head_dim] for fp8 data
-                # * buf[i, page_size * head_dim:].view(float32) for scale
-                (
-                    (size + page_size + 1) // self.page_size,
-                    self.page_size
-                    * (index_head_dim + index_head_dim // self.quant_block_size * 4),
-                ),
-                dtype=torch.uint8,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            self.index_k_with_scale_buffer = [
+                torch.zeros(
+                    # Layout:
+                    # ref: test_attention.py :: kv_cache_cast_to_fp8
+                    # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+                    # data: for page i,
+                    # * buf[i, :page_size * head_dim] for fp8 data
+                    # * buf[i, page_size * head_dim:].view(float32) for scale
+                    (
+                        (size + page_size + 1) // self.page_size,
+                        self.page_size
+                        * (
+                            index_head_dim + index_head_dim // self.quant_block_size * 4
+                        ),
+                    ),
+                    dtype=self.index_k_with_scale_buffer_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
+        self._finalize_allocation_log(size)
 
     def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
         if self.layer_transfer_counter is not None:
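Note: a worked example of the page layout above, using the constants asserted in this constructor (page_size 64, index_head_dim 128, quant_block_size 128): each page packs 64 fp8 rows of 128 bytes followed by 64 fp32 scales of 4 bytes each.

    page_size, index_head_dim, quant_block_size = 64, 128, 128
    row_bytes = index_head_dim + index_head_dim // quant_block_size * 4  # 128 + 4
    page_bytes = page_size * row_bytes
    assert (row_bytes, page_bytes) == (132, 8448)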
@@ -1307,6 +1578,24 @@
             pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
         )
 
+    def get_state_buf_infos(self):
+        data_ptrs = [
+            self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num)
+        ]
+        data_lens = [
+            self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num)
+        ]
+        item_lens = [
+            self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num)
+        ]
+        return data_ptrs, data_lens, item_lens
+
+    def get_kv_size_bytes(self):
+        kv_size_bytes = super().get_kv_size_bytes()
+        for index_k_cache in self.index_k_with_scale_buffer:
+            kv_size_bytes += get_tensor_size_bytes(index_k_cache)
+        return kv_size_bytes
+
 
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
@@ -1531,27 +1820,38 @@ class DoubleSparseTokenToKVPool(KVCache):
         )
 
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
-            # [size, head_num, head_dim] for each layer
-            self.k_buffer = [
-                torch.zeros(
-                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
-                )
-                for _ in range(layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
-                )
-                for _ in range(layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                self.k_buffer = [
+                    torch.zeros(
+                        (size + page_size, head_num, head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (size + page_size, head_num, head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
-            # [size, head_num, heavy_channel_num] for each layer
-            self.label_buffer = [
-                torch.zeros(
-                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-                )
-                for _ in range(layer_num)
-            ]
+                # [size, head_num, heavy_channel_num] for each layer
+                self.label_buffer = [
+                    torch.zeros(
+                        (size + 1, head_num, heavy_channel_num),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
@@ -1584,38 +1884,36 @@
 
 
 @triton.jit
-def copy_all_layer_kv_cache(
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 128
-
+    """2D tiled kernel. Safe for in-place copy."""
     bid = tl.program_id(0)
+    tid = tl.program_id(1)
+
     stride = tl.load(strides + bid)
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))
 
-    data_ptr = tl.load(data_ptrs + bid)
-    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)
 
-    num_locs_offset = tl.arange(0, num_locs_upper)
-    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs
 
-    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
-    # because this copy is an inplace operation.
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
 
-    num_loop = tl.cdiv(stride, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
-        value = tl.load(
-            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-        )
-        tl.store(
-            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-            value,
-            mask=mask,
-        )
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)
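
Note: the 2D grid parallelizes over buffers and byte tiles only; all (src, tgt) location pairs stay inside one program, and within a program every tile is fully loaded before it is stored, which is what keeps the in-place copy safe when a slot appears as both source and target. A byte-level reference sketch of the same semantics in plain PyTorch (advanced indexing materializes the gathered rows before the scatter, so it is also overlap-safe):

    import torch

    def move_kv_reference(buffers, tgt_loc, src_loc):
        for buf in buffers:                    # one [slots, ...] tensor per layer
            rows = buf.view(buf.shape[0], -1)  # flatten each slot row to bytes-like 1D
            rows[tgt_loc] = rows[src_loc]      # row tgt becomes a copy of row src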