sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py
@@ -15,7 +15,12 @@ limitations under the License.
 
 from __future__ import annotations
 
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from dataclasses import dataclass
+
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
+from sglang.srt.layers.attention.nsa import index_buf_accessor
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
 Memory pool.
@@ -42,6 +47,8 @@ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
+    from sglang.srt.managers.schedule_batch import Req
+
 
 logger = logging.getLogger(__name__)
 
@@ -107,92 +114,135 @@ class ReqToTokenPool:
 
 
 class MambaPool:
+    @dataclass(frozen=True, kw_only=True)
+    class State:
+        conv: torch.Tensor
+        temporal: torch.Tensor
+
+        def at_layer_idx(self, layer: int):
+            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+        def mem_usage_bytes(self):
+            return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+    @dataclass(frozen=True, kw_only=True)
+    class SpeculativeState(State):
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: torch.Tensor
+
     def __init__(
         self,
+        *,
         size: int,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        num_mamba_layers: int,
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
+        cache_params: "Mamba2CacheParams",
         device: str,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
-        conv_state = torch.zeros(
-            size=(num_mamba_layers, size + 1) + conv_state_shape,
-            dtype=conv_dtype,
-            device=device,
-        )
-        temporal_state = torch.zeros(
-            size=(num_mamba_layers, size + 1) + temporal_state_shape,
-            dtype=ssm_dtype,
-            device=device,
+        conv_state_shape = cache_params.shape.conv
+        temporal_state_shape = cache_params.shape.temporal
+        conv_dtype = cache_params.dtype.conv
+        ssm_dtype = cache_params.dtype.temporal
+        num_mamba_layers = len(cache_params.layers)
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
         )
-        if speculative_num_draft_tokens is not None:
-            # Cache intermediate SSM states per draft token during target verify
-            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
-            intermediate_ssm_state_cache = torch.zeros(
-                size=(
-                    num_mamba_layers,
-                    size + 1,
-                    speculative_num_draft_tokens,
-                    temporal_state_shape[0],
-                    temporal_state_shape[1],
-                    temporal_state_shape[2],
-                ),
-                dtype=ssm_dtype,
-                device="cuda",
-            )
-            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
-            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
-            intermediate_conv_window_cache = torch.zeros(
-                size=(
-                    num_mamba_layers,
-                    size + 1,
-                    speculative_num_draft_tokens,
-                    conv_state_shape[0],
-                    conv_state_shape[1],
-                ),
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.enable_custom_mem_pool
+            else nullcontext()
+        ):
+            # assume conv_state = (dim, state_len)
+            assert conv_state_shape[0] > conv_state_shape[1]
+            conv_state = torch.zeros(
+                size=(num_mamba_layers, size + 1) + conv_state_shape,
                 dtype=conv_dtype,
-                device="cuda",
+                device=device,
             )
-            self.mamba_cache = (
-                conv_state,
-                temporal_state,
-                intermediate_ssm_state_cache,
-                intermediate_conv_window_cache,
-            )
-            logger.info(
-                f"Mamba Cache is allocated. "
-                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
-                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
-                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
-                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+            temporal_state = torch.zeros(
+                size=(num_mamba_layers, size + 1) + temporal_state_shape,
+                dtype=ssm_dtype,
+                device=device,
             )
-        else:
-            self.mamba_cache = (conv_state, temporal_state)
-            logger.info(
-                f"Mamba Cache is allocated. "
-                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
-                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+            if speculative_num_draft_tokens is not None:
+                # Cache intermediate SSM states per draft token during target verify
+                # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+                intermediate_ssm_state_cache = torch.zeros(
+                    size=(
+                        num_mamba_layers,
+                        size + 1,
+                        speculative_num_draft_tokens,
+                        temporal_state_shape[0],
+                        temporal_state_shape[1],
+                        temporal_state_shape[2],
+                    ),
+                    dtype=ssm_dtype,
+                    device="cuda",
+                )
+                # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+                # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+                intermediate_conv_window_cache = torch.zeros(
+                    size=(
+                        num_mamba_layers,
+                        size + 1,
+                        speculative_num_draft_tokens,
+                        conv_state_shape[0],
+                        conv_state_shape[1],
+                    ),
+                    dtype=conv_dtype,
+                    device="cuda",
+                )
+                self.mamba_cache = self.SpeculativeState(
+                    conv=conv_state,
+                    temporal=temporal_state,
+                    intermediate_ssm=intermediate_ssm_state_cache,
+                    intermediate_conv_window=intermediate_conv_window_cache,
+                )
+                logger.info(
+                    f"Mamba Cache is allocated. "
+                    f"max_mamba_cache_size: {size}, "
+                    f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                    f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                    f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                    f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+                )
+            else:
+                self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
+                logger.info(
+                    f"Mamba Cache is allocated. "
+                    f"max_mamba_cache_size: {size}, "
+                    f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                    f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                )
+        self.size = size
+        self.device = device
+        self.free_slots = torch.arange(
+            self.size, dtype=torch.int64, device=self.device
         )
-        self.size = size
-        self.free_slots = list(range(size))
-        self.mem_usage = self.get_mamba_size() / GB
-
-    def get_mamba_params_all_layers(self):
-        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+        self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
+        self.num_mamba_layers = num_mamba_layers
 
-    def get_mamba_params(self, layer_id: int):
-        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+    def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+        assert isinstance(self.mamba_cache, self.SpeculativeState)
+        return self.mamba_cache
 
-    def get_mamba_size(self):
-        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+    def mamba2_layer_cache(self, layer_id: int):
+        return self.mamba_cache.at_layer_idx(layer_id)
 
     def available_size(self):
         return len(self.free_slots)
 
-    def alloc(self, need_size: int) -> Optional[List[int]]:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        if need_size > len(self.free_slots):
             return None
 
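
Note: the tuple-indexed `self.mamba_cache[i]` accessors are replaced by frozen dataclasses whose fields can be sliced per layer. A minimal standalone sketch of that pattern (with a stand-in for sglang's `get_tensor_size_bytes` helper):

    from dataclasses import dataclass

    import torch


    def get_tensor_size_bytes(t: torch.Tensor) -> int:
        # stand-in for sglang's helper of the same name
        return t.element_size() * t.numel()


    @dataclass(frozen=True, kw_only=True)
    class State:
        conv: torch.Tensor
        temporal: torch.Tensor

        def at_layer_idx(self, layer: int):
            # slice every field along the leading (layer) dimension
            return type(self)(**{k: v[layer] for k, v in vars(self).items()})

        def mem_usage_bytes(self):
            return sum(get_tensor_size_bytes(t) for t in vars(self).values())


    # 2 layers, 4 slots, toy trailing shapes
    s = State(conv=torch.zeros(2, 4, 8, 3), temporal=torch.zeros(2, 4, 5))
    layer0 = s.at_layer_idx(0)
    print(layer0.conv.shape, layer0.temporal.shape)  # [4, 8, 3] and [4, 5]
    print(s.mem_usage_bytes())  # 4 * (2*4*8*3 + 2*4*5) = 928

Because `at_layer_idx` and `mem_usage_bytes` iterate `vars(self)`, `SpeculativeState` inherits both helpers and its two extra buffers are picked up automatically.
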
@@ -201,15 +251,46 @@ class MambaPool:
 
         return select_index
 
-    def free(self, free_index: Union[int, List[int]]):
-        if isinstance(free_index, (int,)):
-            self.free_slots.append(free_index)
-        else:
-            self.free_slots.extend(free_index)
-        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+        self.free_slots = torch.cat((self.free_slots, free_index))
+        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+            :, free_index
+        ] = 0
 
     def clear(self):
-        self.free_slots = list(range(self.size))
+        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
+
+    def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
+        self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
+        self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
+            :, src_index
+        ]
+        return
+
+    def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
+        dst_index = self.alloc(1)
+        if dst_index == None:
+            return None
+        self.copy_from(src_index, dst_index)
+        return dst_index
+
+    def get_contiguous_buf_infos(self):
+        state_tensors = [
+            getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
+        ]
+        data_ptrs, data_lens, item_lens = [], [], []
+
+        for _, state_tensor in enumerate(state_tensors):
+            data_ptrs += [
+                state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
+            ]
+            data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
+            item_lens += [
+                state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
+            ]
+        return data_ptrs, data_lens, item_lens
 
 
 class HybridReqToTokenPool(ReqToTokenPool):
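
Note: the free list also moves from a Python `list` to an int64 device tensor, so slot indices can be consumed and returned without host round-trips. The `alloc` body is only partly visible in this diff; the sketch below assumes it slices slots off the front of `free_slots`, mirroring the 0.5.3 list behavior:

    from typing import Optional

    import torch


    class SlotPool:
        """Tensor-backed free-slot pool mirroring MambaPool's alloc/free."""

        def __init__(self, size: int, device: str = "cpu"):
            self.size = size
            self.device = device
            self.free_slots = torch.arange(size, dtype=torch.int64, device=device)

        def alloc(self, need_size: int) -> Optional[torch.Tensor]:
            if need_size > len(self.free_slots):
                return None
            # take slots from the front, keep the rest
            select_index = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return select_index

        def free(self, free_index: torch.Tensor):
            if free_index.numel() == 0:
                return
            self.free_slots = torch.cat((self.free_slots, free_index))


    pool = SlotPool(4)
    a = pool.alloc(3)      # tensor([0, 1, 2])
    pool.free(a[:2])       # slots 0 and 1 become reusable again
    print(pool.alloc(3))   # tensor([3, 0, 1])
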
@@ -217,16 +298,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
     def __init__(
         self,
+        *,
         size: int,
+        mamba_size: int,
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        mamba_layers: List[int],
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
-        speculative_num_draft_tokens: int,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int = None,
     ):
         super().__init__(
             size=size,
@@ -234,31 +313,37 @@ class HybridReqToTokenPool(ReqToTokenPool):
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
+        self._init_mamba_pool(
+            size=mamba_size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
+        )
 
+    def _init_mamba_pool(
+        self,
+        size: int,
+        cache_params: "Mamba2CacheParams",
+        device: str,
+        speculative_num_draft_tokens: int = None,
+    ):
         self.mamba_pool = MambaPool(
-            size,
-            conv_dtype,
-            ssm_dtype,
-            len(mamba_layers),
-            conv_state_shape,
-            temporal_state_shape,
-            device,
-            speculative_num_draft_tokens,
+            size=size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
-        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
 
         self.device = device
         self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
             size, dtype=torch.int32, device=self.device
         )
 
-        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
-        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
-
     # For chunk prefill req, we do not need to allocate mamba cache,
     # We could use allocated mamba cache instead.
     def alloc(
-        self, need_size: int, reqs: Optional[List["Req"]] = None
+        self, need_size: int, reqs: Optional[List[Req]] = None
     ) -> Optional[List[int]]:
         select_index = super().alloc(need_size)
         if select_index == None:
@@ -266,14 +351,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
         mamba_index = []
         for req in reqs:
-            rid = req.rid
-            if rid in self.rid_to_mamba_index_mapping:
-                mid = self.rid_to_mamba_index_mapping[rid]
-            elif (mid := self.mamba_pool.alloc(1)) is not None:
-                mid = mid[0]
-                self.rid_to_mamba_index_mapping[rid] = mid
-                self.mamba_index_to_rid_mapping[mid] = rid
-            mamba_index.append(mid)
+            mid = None
+            if req.mamba_pool_idx is not None:  # for radix cache
+                mid = req.mamba_pool_idx
+            else:
+                mid = self.mamba_pool.alloc(1)[0]
+                req.mamba_pool_idx = mid
+            if mid is not None:
+                mamba_index.append(mid)
         assert len(select_index) == len(
             mamba_index
         ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
@@ -285,26 +370,21 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
     def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
         return self.req_index_to_mamba_index_mapping[req_indices]
 
-    def get_mamba_params(self, layer_id: int):
+    def mamba2_layer_cache(self, layer_id: int):
         assert layer_id in self.mamba_map
-        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
 
-    def get_mamba_params_all_layers(self):
-        return self.mamba_pool.get_mamba_params_all_layers()
+    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+        return self.mamba_pool.get_speculative_mamba2_params_all_layers()
 
     # For chunk prefill, we can not free mamba cache, we need use it in the future
     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        if isinstance(free_index, (int,)):
+            free_index = [free_index]
         super().free(free_index)
         if free_mamba_cache:
             mamba_index = self.req_index_to_mamba_index_mapping[free_index]
-            mamba_index_list = mamba_index.tolist()
-            if isinstance(mamba_index_list, int):
-                mamba_index_list = [mamba_index_list]
-            self.mamba_pool.free(mamba_index_list)
-            for mid in mamba_index_list:
-                rid = self.mamba_index_to_rid_mapping[mid]
-                self.mamba_index_to_rid_mapping.pop(mid)
-                self.rid_to_mamba_index_mapping.pop(rid)
+            self.mamba_pool.free(mamba_index)
 
     def clear(self):
         super().clear()
@@ -347,6 +427,19 @@ class KVCache(abc.ABC):
         # default state for optional layer-wise transfer control
         self.layer_transfer_counter = None
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
     def _finalize_allocation_log(self, num_tokens: int):
         """Common logging and mem_usage computation for KV cache allocation.
         Supports both tuple (K, V) size returns and single KV size returns.
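
Note: this env-gated allocator setup, previously duplicated in the concrete pools, is now inherited from `KVCache` by every subclass. A self-contained sketch of the pattern, assuming PyTorch >= 2.5 for `torch.cuda.MemPool` / `torch.cuda.use_mem_pool` and treating the Mooncake allocator as an opaque CUDA pluggable allocator:

    import os
    from contextlib import nullcontext

    import torch


    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # simplified form of sglang's helper
        return os.getenv(name, default).lower() in ("true", "1")


    enable = get_bool_env_var("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
    if enable:
        # Mooncake ships a CUDA pluggable allocator whose pages are NVLink-visible
        from mooncake.allocator import NVLinkAllocator

        allocator = NVLinkAllocator.get_allocator("cuda:0")
        custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
    else:
        custom_mem_pool = None

    # every allocation inside the block goes to the custom pool; with the env
    # var off this degenerates to the default caching allocator
    with torch.cuda.use_mem_pool(custom_mem_pool) if custom_mem_pool else nullcontext():
        if torch.cuda.is_available():
            kv_buffer = torch.zeros(1024, device="cuda")
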
@@ -398,6 +491,9 @@ class KVCache(abc.ABC):
     def load_cpu_copy(self, kv_cache_cpu, indices):
         raise NotImplementedError()
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
 
 
 class MHATokenToKVPool(KVCache):
@@ -413,6 +509,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -427,25 +524,61 @@ class MHATokenToKVPool(KVCache):
         self.head_num = head_num
         self.head_dim = head_dim
 
-        # for disagg with nvlink
-        self.enable_custom_mem_pool = get_bool_env_var(
-            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
-        )
-        if self.enable_custom_mem_pool:
-            # TODO(shangming): abstract custom allocator class for more backends
-            from mooncake.allocator import NVLinkAllocator
-
-            allocator = NVLinkAllocator.get_allocator(self.device)
-            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
-        else:
-            self.custom_mem_pool = None
-
         self._create_buffers()
 
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
+
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)
 
+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
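
Note: a worked example of the tiling heuristic in `_init_kv_copy_and_warmup` above; the per-layer stride values below (fp16 K rows of 8 and 32 heads) are illustrative, not taken from the diff:

    def kv_copy_config(stride_bytes: int) -> dict:
        # mirrors the thresholds in _init_kv_copy_and_warmup
        if stride_bytes >= 8192:
            bytes_per_tile = 512
        elif stride_bytes >= 4096:
            bytes_per_tile = 256
        else:
            bytes_per_tile = 128
        return {
            "bytes_per_tile": bytes_per_tile,
            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
            "num_warps": 4 if bytes_per_tile <= 256 else 8,
        }


    # 8 KV heads * head_dim 128 * 2 bytes (fp16) = 2048 bytes per token row
    print(kv_copy_config(2048))  # {'bytes_per_tile': 128, 'byte_tiles': 16, 'num_warps': 4}
    # 32 KV heads * 128 * 2 = 8192 bytes per token row
    print(kv_copy_config(8192))  # {'bytes_per_tile': 512, 'byte_tiles': 16, 'num_warps': 8}
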
@@ -533,9 +666,6 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
-    def maybe_get_custom_mem_pool(self):
-        return self.custom_mem_pool
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -640,13 +770,28 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-            len(tgt_loc),
-            next_power_of_2(len(tgt_loc)),
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
        )
 
 
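
Note: the launch grid for the tiled copy is two-dimensional, one axis per registered buffer pointer and one per byte tile, with the token count padded to a power of two for the kernel's inner loop. A sketch of that arithmetic, assuming an MHA pool registers separate K and V buffer pointers per layer and with a local `next_power_of_2` standing in for sglang's utility:

    def next_power_of_2(n: int) -> int:
        # local stand-in for the helper imported from sglang.srt.utils
        return 1 if n == 0 else 1 << (n - 1).bit_length()


    layer_num, byte_tiles = 32, 16
    num_buffers = 2 * layer_num       # one K and one V buffer pointer per layer
    grid = (num_buffers, byte_tiles)  # one program per (buffer, byte-tile) pair

    N = 50                            # tokens being moved
    N_upper = next_power_of_2(N)      # 64, the padded bound passed to the kernel
    print(grid, N, N_upper)           # (64, 16) 50 64
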
@@ -663,12 +808,18 @@ class HybridLinearKVPool(KVCache):
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
+        mamba_pool: MambaPool,
     ):
         self.size = size
         self.dtype = dtype
         self.device = device
         self.full_layer_nums = len(full_attention_layer_ids)
         self.page_size = page_size
+        # TODO support pp?
+        self.start_layer = 0
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.mamba_pool = mamba_pool
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
         if _is_npu:
@@ -697,6 +848,15 @@ class HybridLinearKVPool(KVCache):
     def get_contiguous_buf_infos(self):
         return self.full_kv_pool.get_contiguous_buf_infos()
 
+    def get_state_buf_infos(self):
+        mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
+            self.mamba_pool.get_contiguous_buf_infos()
+        )
+        return mamba_data_ptrs, mamba_data_lens, mamba_item_lens
+
+    def maybe_get_custom_mem_pool(self):
+        return self.full_kv_pool.maybe_get_custom_mem_pool()
+
     def _transfer_full_attention_id(self, layer_id: int):
         if layer_id not in self.full_attention_layer_id_mapping:
             raise ValueError(
@@ -747,28 +907,57 @@ class SWAKVPool(KVCache):
         self,
         size: int,
         size_swa: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
+        device: str,
         token_to_kv_pool_class: KVCache = MHATokenToKVPool,
         **kwargs,
     ):
         self.size = size
         self.size_swa = size_swa
+        self.dtype = dtype
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.device = device
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
+        self.start_layer = 0
+        self.page_size = 1
+
         kwargs["page_size"] = 1
         kwargs["enable_memory_saver"] = False
+        kwargs["head_num"] = head_num
+        kwargs["head_dim"] = head_dim
+        kwargs["device"] = device
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
+            dtype=dtype,
             layer_num=self.swa_layer_nums,
             **kwargs,
         )
         self.full_kv_pool = token_to_kv_pool_class(
             size=size,
+            dtype=dtype,
             layer_num=self.full_layer_nums,
             **kwargs,
         )
@@ -781,6 +970,9 @@ class SWAKVPool(KVCache):
 
         k_size, v_size = self.get_kv_size_bytes()
         self.mem_usage = (k_size + v_size) / GB
+        logger.info(
+            f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}"
+        )
 
     def get_kv_size_bytes(self):
         k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
@@ -791,15 +983,19 @@ class SWAKVPool(KVCache):
         full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
             self.full_kv_pool.get_contiguous_buf_infos()
         )
+
+        kv_data_ptrs = full_kv_data_ptrs
+        kv_data_lens = full_kv_data_lens
+        kv_item_lens = full_kv_item_lens
+
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+    def get_state_buf_infos(self):
         swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
             self.swa_kv_pool.get_contiguous_buf_infos()
         )
 
-        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
-        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
-        kv_item_lens = full_kv_item_lens + swa_kv_item_lens
-
-        return kv_data_ptrs, kv_data_lens, kv_item_lens
+        return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens
 
     def get_key_buffer(self, layer_id: int):
         layer_id_pool, is_swa = self.layers_mapping[layer_id]
@@ -1030,6 +1226,8 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        use_nsa: bool = False,
+        override_kv_cache_dim: Optional[int] = None,
     ):
         super().__init__(
             size,
@@ -1044,19 +1242,14 @@ class MLATokenToKVPool(KVCache):
 
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
-
-        # for disagg with nvlink
-        self.enable_custom_mem_pool = get_bool_env_var(
-            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        self.use_nsa = use_nsa
+        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
+        # TODO do not hardcode
+        self.kv_cache_dim = (
+            656
+            if self.use_nsa and self.nsa_kv_cache_store_fp8
+            else (kv_lora_rank + qk_rope_head_dim)
         )
-        if self.enable_custom_mem_pool:
-            # TODO(shangming): abstract custom allocator class for more backends
-            from mooncake.allocator import NVLinkAllocator
-
-            allocator = NVLinkAllocator.get_allocator(self.device)
-            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
-        else:
-            self.custom_mem_pool = None
 
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
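
Note: the hardcoded `656` (flagged `TODO do not hardcode`) is consistent with DeepSeek-style MLA dims of `kv_lora_rank=512` and `qk_rope_head_dim=64` when the nope part is stored as fp8 with one fp32 scale per 128-element block and the rope part stays 16-bit. The dims below are assumptions for this back-of-the-envelope check, not values from the diff:

    kv_lora_rank = 512       # assumed MLA latent dim (DeepSeek-style)
    qk_rope_head_dim = 64    # assumed rope dim
    quant_block_size = 128   # matches NSATokenToKVPool.quant_block_size below

    fp8_nope = kv_lora_rank * 1                      # 512 bytes of fp8 payload
    scales = (kv_lora_rank // quant_block_size) * 4  # 4 fp32 scales -> 16 bytes
    rope_16bit = qk_rope_head_dim * 2                # 128 bytes kept in 16-bit
    print(fp8_nope + scales + rope_16bit)            # 656
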
@@ -1067,7 +1260,7 @@ class MLATokenToKVPool(KVCache):
                 # The padded slot 0 is used for writing dummy outputs from padded tokens.
                 self.kv_buffer = [
                     torch.zeros(
-                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        (size + page_size, 1, self.kv_cache_dim),
                         dtype=self.store_dtype,
                         device=device,
                     )
@@ -1079,7 +1272,9 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self._finalize_allocation_log(size)
+        if not use_nsa:
+            # NSA will allocate indexer KV cache later and then log the total size
+            self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -1098,9 +1293,6 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
-    def maybe_get_custom_mem_pool(self):
-        return self.custom_mem_pool
-
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -1130,6 +1322,7 @@ class MLATokenToKVPool(KVCache):
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
+        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -1147,16 +1340,28 @@ class MLATokenToKVPool(KVCache):
  cache_k_rope: torch.Tensor,
  ):
  layer_id = layer.layer_id
- if cache_k_nope.dtype != self.dtype:
- cache_k_nope = cache_k_nope.to(self.dtype)
- cache_k_rope = cache_k_rope.to(self.dtype)
- if self.store_dtype != self.dtype:
- cache_k_nope = cache_k_nope.view(self.store_dtype)
- cache_k_rope = cache_k_rope.view(self.store_dtype)

- set_mla_kv_buffer_triton(
- self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
- )
+ if self.use_nsa and self.nsa_kv_cache_store_fp8:
+ # original cache_k: (num_tokens, num_heads=1, hidden=576); unsqueeze a page_size=1 dim before quantizing
+ # TODO: avoid the concat
+ cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
+ cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
+ cache_k = cache_k.view(self.store_dtype)
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+ else:
+ if cache_k_nope.dtype != self.dtype:
+ cache_k_nope = cache_k_nope.to(self.dtype)
+ cache_k_rope = cache_k_rope.to(self.dtype)
+ if self.store_dtype != self.dtype:
+ cache_k_nope = cache_k_nope.view(self.store_dtype)
+ cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+ set_mla_kv_buffer_triton(
+ self.kv_buffer[layer_id - self.start_layer],
+ loc,
+ cache_k_nope,
+ cache_k_rope,
+ )

  def get_cpu_copy(self, indices):
  torch.cuda.synchronize()
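
Both branches above write through bit-casts rather than conversions: fp8 pools are allocated in store_dtype (uint8 for fp8 caches) and `.view()` reinterprets the bytes without copying. A self-contained round-trip sketch (tensor shapes are illustrative):

import torch

# Sketch of the dtype/store_dtype round-trip: allocate in uint8, write fp8
# values through a byte view, read them back with another bit-cast view.
dtype, store_dtype = torch.float8_e4m3fn, torch.uint8
buf = torch.zeros(4, 656, dtype=store_dtype)   # pool allocated in store_dtype
vals = torch.randn(2, 656).to(dtype)           # fp8 values to cache
buf[1:3] = vals.view(store_dtype)              # write through a byte view
readback = buf[1:3].view(dtype)                # bit-cast back to fp8
assert torch.equal(readback.view(store_dtype), vals.view(store_dtype))
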
@@ -1186,6 +1391,130 @@ class MLATokenToKVPool(KVCache):
  torch.cuda.synchronize()


+ class NSATokenToKVPool(MLATokenToKVPool):
+ quant_block_size = 128
+ index_k_with_scale_buffer_dtype = torch.uint8
+
+ def __init__(
+ self,
+ size: int,
+ page_size: int,
+ kv_lora_rank: int,
+ dtype: torch.dtype,
+ qk_rope_head_dim: int,
+ layer_num: int,
+ device: str,
+ index_head_dim: int,
+ enable_memory_saver: bool,
+ start_layer: Optional[int] = None,
+ end_layer: Optional[int] = None,
+ ):
+ super().__init__(
+ size,
+ page_size,
+ dtype,
+ kv_lora_rank,
+ qk_rope_head_dim,
+ layer_num,
+ device,
+ enable_memory_saver,
+ start_layer,
+ end_layer,
+ use_nsa=True,
+ )
+ # self.index_k_dtype = torch.float8_e4m3fn
+ # self.index_k_scale_dtype = torch.float32
+ self.index_head_dim = index_head_dim
+ # index_k in NSA has num_heads == 1 and head_dim == 128
+ assert index_head_dim == 128
+
+ assert self.page_size == 64
+ with (
+ torch.cuda.use_mem_pool(self.custom_mem_pool)
+ if self.custom_mem_pool
+ else nullcontext()
+ ):
+ self.index_k_with_scale_buffer = [
+ torch.zeros(
+ # Layout:
+ # ref: test_attention.py :: kv_cache_cast_to_fp8
+ # shape: (num_pages, page_size(64) * head_dim(128) + page_size(64) * fp32_nbytes(4))
+ # data: for page i,
+ # * buf[i, :page_size * head_dim] holds the fp8 data
+ # * buf[i, page_size * head_dim:].view(float32) holds the scales
+ (
+ (size + page_size + 1) // self.page_size,
+ self.page_size
+ * (
+ index_head_dim + index_head_dim // self.quant_block_size * 4
+ ),
+ ),
+ dtype=self.index_k_with_scale_buffer_dtype,
+ device=device,
+ )
+ for _ in range(layer_num)
+ ]
+ self._finalize_allocation_log(size)
+
+ def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
+ if self.layer_transfer_counter is not None:
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+ return self.index_k_with_scale_buffer[layer_id - self.start_layer]
+
+ def get_index_k_continuous(
+ self,
+ layer_id: int,
+ seq_len: int,
+ page_indices: torch.Tensor,
+ ):
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+ return index_buf_accessor.GetK.execute(
+ self, buf, seq_len=seq_len, page_indices=page_indices
+ )
+
+ def get_index_k_scale_continuous(
+ self,
+ layer_id: int,
+ seq_len: int,
+ page_indices: torch.Tensor,
+ ):
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+ return index_buf_accessor.GetS.execute(
+ self, buf, seq_len=seq_len, page_indices=page_indices
+ )
+
+ # TODO rename later (a distinct name is used for now to avoid confusion)
+ def set_index_k_and_scale_buffer(
+ self,
+ layer_id: int,
+ loc: torch.Tensor,
+ index_k: torch.Tensor,
+ index_k_scale: torch.Tensor,
+ ) -> None:
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+ index_buf_accessor.SetKAndS.execute(
+ pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
+ )
+
+ def get_state_buf_infos(self):
+ data_ptrs = [
+ self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num)
+ ]
+ data_lens = [
+ self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num)
+ ]
+ item_lens = [
+ self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num)
+ ]
+ return data_ptrs, data_lens, item_lens
+
+ def get_kv_size_bytes(self):
+ kv_size_bytes = super().get_kv_size_bytes()
+ for index_k_cache in self.index_k_with_scale_buffer:
+ kv_size_bytes += get_tensor_size_bytes(index_k_cache)
+ return kv_size_bytes
+
+
  class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  def __init__(
  self,
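
The per-page byte count in the `index_k_with_scale_buffer` allocation above follows directly from the asserted constants; checking the layout math:

# Pure arithmetic using the values the class asserts
# (page_size=64, index_head_dim=128, quant_block_size=128).
page_size, head_dim, quant_block_size = 64, 128, 128

fp8_bytes = page_size * head_dim                               # 8192 bytes of fp8 data
scale_bytes = page_size * (head_dim // quant_block_size) * 4   # 256 bytes of fp32 scales
row_bytes = page_size * (head_dim + head_dim // quant_block_size * 4)
assert row_bytes == fp8_bytes + scale_bytes == 8448
# buf[i, :fp8_bytes] -> fp8 index_k; buf[i, fp8_bytes:].view(float32) -> scales
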
@@ -1194,6 +1523,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  dtype: torch.dtype,
  kv_lora_rank: int,
  qk_rope_head_dim: int,
+ index_head_dim: Optional[int],
  layer_num: int,
  device: str,
  enable_memory_saver: bool,
@@ -1213,6 +1543,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

  self.kv_lora_rank = kv_lora_rank
  self.qk_rope_head_dim = qk_rope_head_dim
+ self.index_head_dim = index_head_dim

  self.custom_mem_pool = None

@@ -1240,6 +1571,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  dtype=self.store_dtype,
  device=self.device,
  )
+ if self.index_head_dim is not None:
+ self.index_k_buffer = torch.zeros(
+ (
+ layer_num,
+ self.size // self.page_size + 1,
+ self.page_size,
+ 1,
+ self.index_head_dim,
+ ),
+ dtype=self.store_dtype,
+ device=self.device,
+ )

  self._finalize_allocation_log(size)
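
The extra indexer buffer is easy to size from the shape above; a back-of-envelope sketch where every concrete value is an illustrative stand-in, not a number from this diff:

# Illustrative sizing of index_k_buffer; layer_num/size/page_size/elem_bytes
# are made-up example values, only the formula mirrors the allocation above.
layer_num, size, page_size, index_head_dim, elem_bytes = 61, 65536, 128, 128, 2

num_pages = size // page_size + 1
index_k_bytes = layer_num * num_pages * page_size * 1 * index_head_dim * elem_bytes
print(f"index_k buffer: {index_k_bytes / 2**30:.2f} GiB")  # ~0.95 GiB with these values
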
 
@@ -1251,6 +1594,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  kv_size_bytes += get_tensor_size_bytes(k_cache)
  for v_cache in self.v_buffer:
  kv_size_bytes += get_tensor_size_bytes(v_cache)
+ if self.index_head_dim is not None:
+ assert hasattr(self, "index_k_buffer")
+ for index_k_cache in self.index_k_buffer:
+ kv_size_bytes += get_tensor_size_bytes(index_k_cache)
  return kv_size_bytes

  def get_kv_buffer(self, layer_id: int):
@@ -1277,6 +1624,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
  return self.v_buffer[layer_id - self.start_layer]

+ def get_index_k_buffer(self, layer_id: int):
+ if self.layer_transfer_counter is not None:
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+ if self.store_dtype != self.dtype:
+ return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
+ return self.index_k_buffer[layer_id - self.start_layer]
+
  # for disagg
  def get_contiguous_buf_infos(self):
  # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -1289,6 +1644,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
  self.v_buffer[i][0].nbytes for i in range(self.layer_num)
  ]
+ if self.index_head_dim is not None:
+ kv_data_ptrs += [
+ self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
+ ]
+ kv_data_lens += [
+ self.index_k_buffer[i].nbytes for i in range(self.layer_num)
+ ]
+ kv_item_lens += [
+ self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
+ ]
  return kv_data_ptrs, kv_data_lens, kv_item_lens

  def set_kv_buffer(
@@ -1325,6 +1690,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  cache_v.view(-1, 1, self.qk_rope_head_dim),
  )

+ def set_index_k_buffer(
+ self,
+ layer_id: int,
+ loc: torch.Tensor,
+ index_k: torch.Tensor,
+ ):
+ if index_k.dtype != self.dtype:
+ index_k = index_k.to(self.dtype)
+
+ if self.store_dtype != self.dtype:
+ index_k = index_k.view(self.store_dtype)
+
+ torch_npu.npu_scatter_nd_update_(
+ self.index_k_buffer[layer_id - self.start_layer].view(
+ -1, 1, self.index_head_dim
+ ),
+ loc.view(-1, 1),
+ index_k.view(-1, 1, self.index_head_dim),
+ )
+

  class DoubleSparseTokenToKVPool(KVCache):
  def __init__(
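
The NPU scatter above writes one index_k row per token at its flat slot index; on CPU/CUDA the analogue is plain advanced indexing. A minimal illustration of the same scatter semantics (shapes are hypothetical):

import torch

head_dim, num_slots = 128, 1024
buf = torch.zeros(num_slots, 1, head_dim)     # flattened per-layer index_k buffer
loc = torch.tensor([5, 42, 777])              # destination slots for 3 tokens
index_k = torch.randn(3, 1, head_dim)         # rows to write

buf[loc] = index_k                            # in-place row scatter
assert torch.equal(buf[42], index_k[1])
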
@@ -1353,27 +1738,38 @@ class DoubleSparseTokenToKVPool(KVCache):
  )

  with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
- # [size, head_num, head_dim] for each layer
- self.k_buffer = [
- torch.zeros(
- (size + page_size, head_num, head_dim), dtype=dtype, device=device
- )
- for _ in range(layer_num)
- ]
- self.v_buffer = [
- torch.zeros(
- (size + page_size, head_num, head_dim), dtype=dtype, device=device
- )
- for _ in range(layer_num)
- ]
+ with (
+ torch.cuda.use_mem_pool(self.custom_mem_pool)
+ if self.enable_custom_mem_pool
+ else nullcontext()
+ ):
+ # [size, head_num, head_dim] for each layer
+ self.k_buffer = [
+ torch.zeros(
+ (size + page_size, head_num, head_dim),
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(layer_num)
+ ]
+ self.v_buffer = [
+ torch.zeros(
+ (size + page_size, head_num, head_dim),
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(layer_num)
+ ]

- # [size, head_num, heavy_channel_num] for each layer
- self.label_buffer = [
- torch.zeros(
- (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
- )
- for _ in range(layer_num)
- ]
+ # [size, head_num, heavy_channel_num] for each layer
+ self.label_buffer = [
+ torch.zeros(
+ (size + 1, head_num, heavy_channel_num),
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(layer_num)
+ ]

  def get_key_buffer(self, layer_id: int):
  return self.k_buffer[layer_id - self.start_layer]
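
The refactor routes every allocation through an optional custom memory pool, with `nullcontext()` keeping a single code path when no pool is configured. A minimal sketch of the pattern (the pool variable is illustrative):

from contextlib import nullcontext

import torch

custom_mem_pool = None  # e.g. a torch.cuda.MemPool backed by a custom allocator

with (
    torch.cuda.use_mem_pool(custom_mem_pool)
    if custom_mem_pool is not None
    else nullcontext()
):
    # Comes from the custom pool when one is active; otherwise the default
    # CUDA caching allocator is used, with no branching at the call site.
    buf = torch.zeros(16, 16)
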
@@ -1406,38 +1802,36 @@ class DoubleSparseTokenToKVPool(KVCache):


  @triton.jit
- def copy_all_layer_kv_cache(
+ def copy_all_layer_kv_cache_tiled(
  data_ptrs,
  strides,
  tgt_loc_ptr,
  src_loc_ptr,
  num_locs,
  num_locs_upper: tl.constexpr,
+ BYTES_PER_TILE: tl.constexpr,
  ):
- BLOCK_SIZE: tl.constexpr = 128
-
+ """2D tiled kernel. Safe for in-place copy."""
  bid = tl.program_id(0)
+ tid = tl.program_id(1)
+
  stride = tl.load(strides + bid)
+ base_ptr = tl.load(data_ptrs + bid)
+ base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))

- data_ptr = tl.load(data_ptrs + bid)
- data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+ byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+ mask_byte = byte_off < stride
+ tl.multiple_of(byte_off, 16)

- num_locs_offset = tl.arange(0, num_locs_upper)
- tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
- src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+ loc_idx = tl.arange(0, num_locs_upper)
+ mask_loc = loc_idx < num_locs

- # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
- # because this copy is an inplace operation.
+ src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+ tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)

- num_loop = tl.cdiv(stride, BLOCK_SIZE)
- for i in range(num_loop):
- copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
- mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
- value = tl.load(
- data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
- )
- tl.store(
- data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
- value,
- mask=mask,
- )
+ src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+ tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+ mask = mask_loc[:, None] & mask_byte[None, :]
+ vals = tl.load(src_ptr, mask=mask)
+ tl.store(tgt_ptr, vals, mask=mask)
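
The old kernel's NOTE ruled out parallelizing over locations because the copy is in-place; the tiled rewrite instead parallelizes over disjoint byte tiles, and each program still loads all of its source rows before storing any target row, which keeps the in-place copy race-free. A hypothetical host-side launch consistent with the new signature (all variable names here are illustrative, not from the diff):

import triton

BYTES_PER_TILE = 4096  # illustrative tile size
grid = (
    data_ptrs.numel(),                                        # one program per layer buffer
    triton.cdiv(int(strides.max().item()), BYTES_PER_TILE),   # byte tiles per row
)
copy_all_layer_kv_cache_tiled[grid](
    data_ptrs, strides, tgt_loc, src_loc, num_locs,
    num_locs_upper=triton.next_power_of_2(num_locs),
    BYTES_PER_TILE=BYTES_PER_TILE,
)
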