sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from __future__ import annotations
+
+from sglang.srt.layers.attention.nsa import index_buf_accessor
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 """
@@ -27,7 +31,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -38,6 +42,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
@@ -47,6 +54,10 @@ if _is_npu:
     import torch_npu


+def get_tensor_size_bytes(t: torch.Tensor):
+    return np.prod(t.shape) * t.dtype.itemsize
+
+
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""

@@ -97,6 +108,211 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


+class MambaPool:
+    def __init__(
+        self,
+        size: int,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        num_mamba_layers: int,
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        device: str,
+        speculative_num_draft_tokens: Optional[int] = None,
+    ):
+        conv_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + conv_state_shape,
+            dtype=conv_dtype,
+            device=device,
+        )
+        temporal_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + temporal_state_shape,
+            dtype=ssm_dtype,
+            device=device,
+        )
+        if speculative_num_draft_tokens is not None:
+            # Cache intermediate SSM states per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+            intermediate_ssm_state_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    temporal_state_shape[0],
+                    temporal_state_shape[1],
+                    temporal_state_shape[2],
+                ),
+                dtype=ssm_dtype,
+                device="cuda",
+            )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
+            self.mamba_cache = (
+                conv_state,
+                temporal_state,
+                intermediate_ssm_state_cache,
+                intermediate_conv_window_cache,
+            )
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+            )
+        else:
+            self.mamba_cache = (conv_state, temporal_state)
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+            )
+        self.size = size
+        self.free_slots = list(range(size))
+        self.mem_usage = self.get_mamba_size() / GB
+
+    def get_mamba_params_all_layers(self):
+        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_params(self, layer_id: int):
+        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_size(self):
+        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, (int,)):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+
+    def clear(self):
+        self.free_slots = list(range(self.size))
+
+
+class HybridReqToTokenPool(ReqToTokenPool):
+    """A memory pool that maps a request to its token locations and mamba cache slot."""
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        mamba_layers: List[int],
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        speculative_num_draft_tokens: int,
+    ):
+        super().__init__(
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+        )
+
+        self.mamba_pool = MambaPool(
+            size,
+            conv_dtype,
+            ssm_dtype,
+            len(mamba_layers),
+            conv_state_shape,
+            temporal_state_shape,
+            device,
+            speculative_num_draft_tokens,
+        )
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+
+        self.device = device
+        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
+            size, dtype=torch.int32, device=self.device
+        )
+
+        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+    # For chunked prefill requests we do not need to allocate a new mamba cache
+    # slot; we can reuse the one already allocated for the request.
+    def alloc(
+        self, need_size: int, reqs: Optional[List["Req"]] = None
+    ) -> Optional[List[int]]:
+        select_index = super().alloc(need_size)
+        if select_index == None:
+            return None
+
+        mamba_index = []
+        for req in reqs:
+            rid = req.rid
+            if rid in self.rid_to_mamba_index_mapping:
+                mid = self.rid_to_mamba_index_mapping[rid]
+            elif (mid := self.mamba_pool.alloc(1)) is not None:
+                mid = mid[0]
+                self.rid_to_mamba_index_mapping[rid] = mid
+                self.mamba_index_to_rid_mapping[mid] = rid
+            mamba_index.append(mid)
+        assert len(select_index) == len(
+            mamba_index
+        ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
+        return select_index
+
+    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+        return self.req_index_to_mamba_index_mapping[req_indices]
+
+    def get_mamba_params(self, layer_id: int):
+        assert layer_id in self.mamba_map
+        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+
+    def get_mamba_params_all_layers(self):
+        return self.mamba_pool.get_mamba_params_all_layers()
+
+    # For chunked prefill we cannot free the mamba cache yet; it is still needed later.
+    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        super().free(free_index)
+        if free_mamba_cache:
+            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+            mamba_index_list = mamba_index.tolist()
+            if isinstance(mamba_index_list, int):
+                mamba_index_list = [mamba_index_list]
+            self.mamba_pool.free(mamba_index_list)
+            for mid in mamba_index_list:
+                rid = self.mamba_index_to_rid_mapping[mid]
+                self.mamba_index_to_rid_mapping.pop(mid)
+                self.rid_to_mamba_index_mapping.pop(rid)
+
+    def clear(self):
+        super().clear()
+        self.mamba_pool.clear()
+
+
 class KVCache(abc.ABC):
     @abc.abstractmethod
     def __init__(
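The new `MambaPool` is a free-list allocator over preallocated per-layer conv/SSM state tensors: `alloc` pops slot indices, `free` returns them and zeroes the corresponding states, and `HybridReqToTokenPool` keeps a request-id to slot mapping on top so a chunked-prefill request keeps its slot across chunks. A minimal sketch of the slot lifecycle, assuming the class exactly as added above and using small CPU-friendly shapes (illustrative values only):

import torch

pool = MambaPool(
    size=4,                        # four request slots
    conv_dtype=torch.float16,
    ssm_dtype=torch.float32,
    num_mamba_layers=2,
    conv_state_shape=(8, 3),       # (dim, K-1), per the shape comments above
    temporal_state_shape=(4, 16),  # illustrative; the real shape comes from the model config
    device="cpu",
)
slots = pool.alloc(2)              # -> [0, 1]
assert pool.available_size() == 2
pool.free(slots)                   # slots return to the free list; their states are zeroed
assert pool.available_size() == 4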
@@ -130,6 +346,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192

+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
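`_finalize_allocation_log` consolidates the allocation logging that each pool previously duplicated in its `__init__` (the matching removals appear in the MHATokenToKVPool, MLATokenToKVPool, and AscendMLAPagedTokenToKVPool hunks below). It duck-types on `get_kv_size_bytes()`: MHA-style pools return a `(k_bytes, v_bytes)` tuple, MLA-style pools a single byte count. A sketch of the two call patterns, using hypothetical stand-in classes and assuming the module-level `GB` constant:

class _MHAStyle:
    def get_kv_size_bytes(self):
        return (2 * GB, 1 * GB)    # MHA pools return (K bytes, V bytes)

class _MLAStyle:
    def get_kv_size_bytes(self):
        return 3 * GB              # MLA pools return one combined byte count

KVCache._finalize_allocation_log(_MHAStyle(), num_tokens=1 << 20)  # logs K and V sizes; sets mem_usage == 3.0
KVCache._finalize_allocation_log(_MLAStyle(), num_tokens=1 << 20)  # logs a single KV size; sets mem_usage == 3.0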
@@ -152,7 +391,7 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter

     def get_cpu_copy(self, indices):
@@ -205,15 +444,9 @@ class MHATokenToKVPool(KVCache):

         self._create_buffers()

-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
-
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)

     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -269,10 +502,10 @@ class MHATokenToKVPool(KVCache):
         assert hasattr(self, "v_buffer")
         k_size_bytes = 0
         for k_cache in self.k_buffer:
-            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            k_size_bytes += get_tensor_size_bytes(k_cache)
         v_size_bytes = 0
         for v_cache in self.v_buffer:
-            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            v_size_bytes += get_tensor_size_bytes(v_cache)
         return k_size_bytes, v_size_bytes

     # for disagg
@@ -352,7 +585,6 @@ class MHATokenToKVPool(KVCache):
         # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
         return self._get_key_buffer(layer_id)

     def _get_value_buffer(self, layer_id: int):
@@ -420,41 +652,31 @@ class MHATokenToKVPool(KVCache):
             )


-class SWAKVPool(KVCache):
-    """KV cache with separate pools for full and SWA attention layers."""
+class HybridLinearKVPool(KVCache):
+    """KV cache with separate pools for full and linear attention layers."""

     def __init__(
         self,
         size: int,
-        size_swa: int,
         dtype: torch.dtype,
+        page_size: int,
         head_num: int,
         head_dim: int,
-        swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
     ):
         self.size = size
-        self.size_swa = size_swa
         self.dtype = dtype
         self.device = device
-        self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        self.page_size = page_size
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-        TokenToKVPoolClass = MHATokenToKVPool
-        self.swa_kv_pool = TokenToKVPoolClass(
-            size=size_swa,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
-            layer_num=self.swa_layer_nums,
-            device=device,
-            enable_memory_saver=False,
-        )
+        if _is_npu:
+            TokenToKVPoolClass = AscendTokenToKVPool
+        else:
+            TokenToKVPoolClass = MHATokenToKVPool
         self.full_kv_pool = TokenToKVPoolClass(
             size=size,
             page_size=self.page_size,
@@ -465,6 +687,93 @@ class SWAKVPool(KVCache):
             device=device,
             enable_memory_saver=False,
         )
+        self.full_attention_layer_id_mapping = {
+            id: i for i, id in enumerate(full_attention_layer_ids)
+        }
+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
+    def get_kv_size_bytes(self):
+        return self.full_kv_pool.get_kv_size_bytes()
+
+    def get_contiguous_buf_infos(self):
+        return self.full_kv_pool.get_contiguous_buf_infos()
+
+    def _transfer_full_attention_id(self, layer_id: int):
+        if layer_id not in self.full_attention_layer_id_mapping:
+            raise ValueError(
+                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+            )
+        return self.full_attention_layer_id_mapping[layer_id]
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_key_buffer(layer_id)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_value_buffer(layer_id)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_kv_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+        layer_id = self._transfer_full_attention_id(layer.layer_id)
+        self.full_kv_pool.set_kv_buffer(
+            None,
+            loc,
+            cache_k,
+            cache_v,
+            k_scale,
+            v_scale,
+            layer_id_override=layer_id,
+        )
+
+    def get_v_head_dim(self):
+        return self.full_kv_pool.get_value_buffer(0).shape[-1]
+
+
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+
+        self.swa_kv_pool = token_to_kv_pool_class(
+            size=size_swa,
+            layer_num=self.swa_layer_nums,
+            **kwargs,
+        )
+        self.full_kv_pool = token_to_kv_pool_class(
+            size=size,
+            layer_num=self.full_layer_nums,
+            **kwargs,
+        )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
             self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
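`HybridLinearKVPool` only materializes KV buffers for the full-attention layers, so `_transfer_full_attention_id` maps a model-global layer id to the dense index used by the underlying pool; linear-attention layer ids are rejected. A small sketch with hypothetical layer ids:

full_attention_layer_ids = [3, 7, 11]   # hypothetical global ids of the full-attention layers
mapping = {id: i for i, id in enumerate(full_attention_layer_ids)}
assert mapping == {3: 0, 7: 1, 11: 2}
# get_key_buffer(7) then reads full_kv_pool.get_key_buffer(1);
# a linear-attention layer id such as 5 raises ValueError.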
@@ -613,8 +922,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
@@ -719,6 +1032,8 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        use_nsa: bool = False,
+        override_kv_cache_dim: Optional[int] = None,
     ):
         super().__init__(
             size,
@@ -733,6 +1048,14 @@ class MLATokenToKVPool(KVCache):

         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.use_nsa = use_nsa
+        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
+        # TODO do not hardcode
+        self.kv_cache_dim = (
+            656
+            if self.use_nsa and self.nsa_kv_cache_store_fp8
+            else (kv_lora_rank + qk_rope_head_dim)
+        )

         # for disagg with nvlink
         self.enable_custom_mem_pool = get_bool_env_var(
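The hardcoded `656` carries a `TODO do not hardcode` and is not derived anywhere in the diff, but it is consistent with an FP8 NSA cache row for DeepSeek-style MLA dimensions (kv_lora_rank=512, qk_rope_head_dim=64). The following breakdown is an inference from those dimensions, not something the diff states:

kv_lora_rank, qk_rope_head_dim, quant_block = 512, 64, 128
fp8_nope = kv_lora_rank                     # 512 bytes of fp8 "nope" data
scales   = kv_lora_rank // quant_block * 4  # 4 float32 scales = 16 bytes
rope_16b = qk_rope_head_dim * 2             # rope part kept at 16-bit = 128 bytes
assert fp8_nope + scales + rope_16b == 656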
@@ -756,7 +1079,7 @@ class MLATokenToKVPool(KVCache):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.zeros(
-                (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                (size + page_size, 1, self.kv_cache_dim),
                 dtype=self.store_dtype,
                 device=device,
             )
@@ -768,19 +1091,13 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)

     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
         kv_size_bytes = 0
         for kv_cache in self.kv_buffer:
-            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(kv_cache)
         return kv_size_bytes

     # for disagg
@@ -825,6 +1142,7 @@ class MLATokenToKVPool(KVCache):
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
+        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
842
1160
  cache_k_rope: torch.Tensor,
843
1161
  ):
844
1162
  layer_id = layer.layer_id
845
- if cache_k_nope.dtype != self.dtype:
846
- cache_k_nope = cache_k_nope.to(self.dtype)
847
- cache_k_rope = cache_k_rope.to(self.dtype)
848
- if self.store_dtype != self.dtype:
849
- cache_k_nope = cache_k_nope.view(self.store_dtype)
850
- cache_k_rope = cache_k_rope.view(self.store_dtype)
851
1163
 
852
- set_mla_kv_buffer_triton(
853
- self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
854
- )
1164
+ if self.use_nsa and self.nsa_kv_cache_store_fp8:
1165
+ # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
1166
+ # TODO no need to cat
1167
+ cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
1168
+ cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
1169
+ cache_k = cache_k.view(self.store_dtype)
1170
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
1171
+ else:
1172
+ if cache_k_nope.dtype != self.dtype:
1173
+ cache_k_nope = cache_k_nope.to(self.dtype)
1174
+ cache_k_rope = cache_k_rope.to(self.dtype)
1175
+ if self.store_dtype != self.dtype:
1176
+ cache_k_nope = cache_k_nope.view(self.store_dtype)
1177
+ cache_k_rope = cache_k_rope.view(self.store_dtype)
1178
+
1179
+ set_mla_kv_buffer_triton(
1180
+ self.kv_buffer[layer_id - self.start_layer],
1181
+ loc,
1182
+ cache_k_nope,
1183
+ cache_k_rope,
1184
+ )
855
1185
 
856
1186
  def get_cpu_copy(self, indices):
857
1187
  torch.cuda.synchronize()
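In the FP8 NSA branch above, the nope and rope parts are concatenated back into one 576-wide key, quantized page-wise, and written with plain indexing instead of the Triton kernel. A sketch of the shape flow, restating the diff's own comments under the same assumed dimensions (nope 512, rope 64); the 656-wide result follows from kv_cache_dim in the earlier hunk:

# cache_k_nope: (num_tokens, 1, 512); cache_k_rope: (num_tokens, 1, 64)
# torch.cat(dim=-1)     -> (num_tokens, 1, 576)
# .unsqueeze(1)         -> (num_tokens, page_size=1, 1, 576) for quantize_k_cache
# quantize_k_cache(...) -> packed fp8 data plus scales, 656 bytes wide
# .squeeze(1)           -> (num_tokens, 1, 656), stored via kv_buffer[layer][loc] = cache_k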
@@ -881,6 +1211,103 @@ class MLATokenToKVPool(KVCache):
         torch.cuda.synchronize()


+class NSATokenToKVPool(MLATokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        kv_lora_rank: int,
+        dtype: torch.dtype,
+        qk_rope_head_dim: int,
+        layer_num: int,
+        device: str,
+        index_head_dim: int,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+            use_nsa=True,
+        )
+        # self.index_k_dtype = torch.float8_e4m3fn
+        # self.index_k_scale_dtype = torch.float32
+        self.index_head_dim = index_head_dim
+        # num head == 1 and head dim == 128 for index_k in NSA
+        assert index_head_dim == 128
+
+        self.quant_block_size = 128
+
+        assert self.page_size == 64
+        self.index_k_with_scale_buffer = [
+            torch.zeros(
+                # Layout:
+                # ref: test_attention.py :: kv_cache_cast_to_fp8
+                # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+                # data: for page i,
+                # * buf[i, :page_size * head_dim] for fp8 data
+                # * buf[i, page_size * head_dim:].view(float32) for scale
+                (
+                    (size + page_size + 1) // self.page_size,
+                    self.page_size
+                    * (index_head_dim + index_head_dim // self.quant_block_size * 4),
+                ),
+                dtype=torch.uint8,
+                device=device,
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self.index_k_with_scale_buffer[layer_id - self.start_layer]
+
+    def get_index_k_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetK.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    def get_index_k_scale_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetS.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    # TODO rename later (currently use diff name to avoid confusion)
+    def set_index_k_and_scale_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+        index_k_scale: torch.Tensor,
+    ) -> None:
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        index_buf_accessor.SetKAndS.execute(
+            pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
+        )
+
+
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
         self,
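`NSATokenToKVPool` packs the FP8 indexer keys and their float32 scales into a single flat uint8 row per page, per layer. Under the asserted `page_size == 64`, `index_head_dim == 128`, and `quant_block_size == 128`, each row works out as follows (arithmetic only, mirroring the layout comment in the code above):

page_size, head_dim, quant_block = 64, 128, 128
fp8_data  = page_size * head_dim                       # 8192 bytes of fp8 index_k
scales    = page_size * (head_dim // quant_block) * 4  # 64 float32 scales = 256 bytes
row_bytes = page_size * (head_dim + head_dim // quant_block * 4)
assert row_bytes == fp8_data + scales == 8448
# buf[i, :8192] is the fp8 data for page i; buf[i, 8192:].view(torch.float32) the scales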
@@ -889,6 +1316,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
+        index_head_dim: Optional[int],
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
@@ -908,6 +1336,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.index_head_dim = index_head_dim

         self.custom_mem_pool = None

@@ -935,23 +1364,33 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             dtype=self.store_dtype,
             device=self.device,
         )
+        if self.index_head_dim is not None:
+            self.index_k_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    1,
+                    self.index_head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )

-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)

     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
         assert hasattr(self, "v_buffer")
         kv_size_bytes = 0
         for k_cache in self.k_buffer:
-            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(k_cache)
         for v_cache in self.v_buffer:
-            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(v_cache)
+        if self.index_head_dim is not None:
+            assert hasattr(self, "index_k_buffer")
+            for index_k_cache in self.index_k_buffer:
+                kv_size_bytes += get_tensor_size_bytes(index_k_cache)
         return kv_size_bytes

     def get_kv_buffer(self, layer_id: int):
@@ -978,6 +1417,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]

+    def get_index_k_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.index_k_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -990,6 +1437,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
             self.v_buffer[i][0].nbytes for i in range(self.layer_num)
         ]
+        if self.index_head_dim is not None:
+            kv_data_ptrs += [
+                self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
+            ]
+            kv_data_lens += [
+                self.index_k_buffer[i].nbytes for i in range(self.layer_num)
+            ]
+            kv_item_lens += [
+                self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
+            ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

     def set_kv_buffer(
@@ -1026,6 +1483,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             cache_v.view(-1, 1, self.qk_rope_head_dim),
         )

+    def set_index_k_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+    ):
+        if index_k.dtype != self.dtype:
+            index_k = index_k.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            index_k = index_k.view(self.store_dtype)
+
+        torch_npu.npu_scatter_nd_update_(
+            self.index_k_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.index_head_dim
+            ),
+            loc.view(-1, 1),
+            index_k.view(-1, 1, self.index_head_dim),
+        )
+

 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(