sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,8 @@ from typing import Any, List, Optional
7
7
 
8
8
  import torch
9
9
 
10
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
11
+
10
12
  logger = logging.getLogger(__name__)
11
13
 
12
14
 
@@ -32,15 +34,46 @@ class HiCacheStorageConfig:
32
34
  extra_config: Optional[dict] = None
33
35
 
34
36
 
37
+ @dataclass
38
+ class HiCacheStorageExtraInfo:
39
+ extra_info: Optional[dict] = None
40
+
41
+
35
42
  class HiCacheStorage(ABC):
36
43
  """
37
44
  HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
38
45
  It abstracts the underlying storage mechanism, allowing different implementations to be used.
39
46
  """
40
47
 
41
- # todo, potentially pass model and TP configs into storage backend
42
48
  # todo, the page size of storage backend does not have to be the same as the same as host memory pool
43
49
 
50
+ def register_mem_pool_host(self, mem_pool_host: HostKVCache):
51
+ self.mem_pool_host = mem_pool_host
52
+
53
+ def batch_get_v1(
54
+ self,
55
+ keys: List[str],
56
+ host_indices: torch.Tensor,
57
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
58
+ ) -> List[bool]:
59
+ """
60
+ Retrieve values for multiple keys.
61
+ Returns a list of tensors or None for each key.
62
+ """
63
+ pass
64
+
65
+ def batch_set_v1(
66
+ self,
67
+ keys: List[str],
68
+ host_indices: torch.Tensor,
69
+ extra_info: Optional[HiCacheStorageExtraInfo] = None,
70
+ ) -> List[bool]:
71
+ """
72
+ Retrieve values for multiple keys.
73
+ Returns a list of tensors or None for each key.
74
+ """
75
+ pass
76
+
44
77
  @abstractmethod
45
78
  def get(
46
79
  self,
@@ -54,6 +87,7 @@ class HiCacheStorage(ABC):
54
87
  """
55
88
  pass
56
89
 
90
+ # TODO: Deprecate
57
91
  @abstractmethod
58
92
  def batch_get(
59
93
  self,
@@ -81,6 +115,7 @@ class HiCacheStorage(ABC):
81
115
  """
82
116
  pass
83
117
 
118
+ # TODO: Deprecate
84
119
  @abstractmethod
85
120
  def batch_set(
86
121
  self,
@@ -103,6 +138,7 @@ class HiCacheStorage(ABC):
103
138
  """
104
139
  pass
105
140
 
141
+ # TODO: Use a finer-grained return type (e.g., List[bool])
106
142
  def batch_exists(self, keys: List[str]) -> int:
107
143
  """
108
144
  Check if the keys exist in the storage.
@@ -114,6 +150,9 @@ class HiCacheStorage(ABC):
114
150
  return i
115
151
  return len(keys)
116
152
 
153
+ def clear(self) -> None:
154
+ pass
155
+
117
156
  def get_stats(self):
118
157
  return None
119
158
 
@@ -1,8 +1,8 @@
1
1
  import heapq
2
+ import json
2
3
  import logging
3
4
  import threading
4
5
  import time
5
- from queue import Queue
6
6
  from typing import List, Optional
7
7
 
8
8
  import torch
@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
19
19
  MHATokenToKVPoolHost,
20
20
  MLATokenToKVPoolHost,
21
21
  )
22
- from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
22
+ from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
23
23
  from sglang.srt.metrics.collector import StorageMetricsCollector
24
24
 
25
25
  logger = logging.getLogger(__name__)
@@ -39,17 +39,19 @@ class HiRadixCache(RadixCache):
39
39
  hicache_io_backend: str,
40
40
  hicache_mem_layout: str,
41
41
  enable_metrics: bool,
42
+ eviction_policy: str = "lru",
42
43
  hicache_storage_backend: Optional[str] = None,
43
44
  hicache_storage_prefetch_policy: Optional[str] = "best_effort",
44
45
  model_name: Optional[str] = None,
45
46
  storage_backend_extra_config: Optional[str] = None,
47
+ is_eagle: bool = False,
46
48
  ):
47
49
 
48
50
  if hicache_io_backend == "direct":
49
51
  if hicache_mem_layout == "page_first":
50
- hicache_mem_layout = "layer_first"
52
+ hicache_mem_layout = "page_first_direct"
51
53
  logger.warning(
52
- "Page first layout is not supported with direct IO backend, switching to layer first layout"
54
+ "Page first layout is not supported with direct IO backend, switching to page first direct layout"
53
55
  )
54
56
 
55
57
  self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
@@ -77,9 +79,19 @@ class HiRadixCache(RadixCache):
77
79
  self.enable_storage = hicache_storage_backend is not None
78
80
  self.enable_storage_metrics = self.enable_storage and enable_metrics
79
81
 
80
- # todo: customizable storage prefetch threshold and timeout
81
- self.prefetch_threshold = 256
82
- self.prefetch_timeout = 3 # seconds
82
+ (
83
+ extra_config,
84
+ prefetch_threshold,
85
+ prefetch_timeout_base,
86
+ prefetch_timeout_per_ki_token,
87
+ ) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
88
+ self.prefetch_threshold = prefetch_threshold
89
+ self.prefetch_timeout_base = prefetch_timeout_base
90
+ self.prefetch_timeout_per_page = (
91
+ page_size / 1024 * prefetch_timeout_per_ki_token
92
+ )
93
+ # TODO: support more timeout check functions
94
+ self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
83
95
  self.prefetch_stop_policy = hicache_storage_prefetch_policy
84
96
 
85
97
  self.load_cache_event = threading.Event()
@@ -94,7 +106,7 @@ class HiRadixCache(RadixCache):
94
106
  storage_backend=hicache_storage_backend,
95
107
  prefetch_threshold=self.prefetch_threshold,
96
108
  model_name=model_name,
97
- storage_backend_extra_config=storage_backend_extra_config,
109
+ storage_backend_extra_config=extra_config,
98
110
  )
99
111
  if self.enable_storage_metrics:
100
112
  # TODO: support pp
@@ -117,8 +129,61 @@ class HiRadixCache(RadixCache):
117
129
  1 if hicache_write_policy == "write_through" else 2
118
130
  )
119
131
  self.load_back_threshold = 10
132
+
120
133
  super().__init__(
121
- req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
134
+ req_to_token_pool,
135
+ token_to_kv_pool_allocator,
136
+ page_size,
137
+ disable=False,
138
+ eviction_policy=eviction_policy,
139
+ is_eagle=is_eagle,
140
+ )
141
+
142
+ def _parse_storage_backend_extra_config(
143
+ self, storage_backend_extra_config: Optional[str]
144
+ ):
145
+ """
146
+ Parse storage backend extra config JSON and extract specific parameters.
147
+
148
+ Args:
149
+ storage_backend_extra_config: JSON string containing extra configuration
150
+
151
+ Returns:
152
+ tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
153
+ """
154
+ # Parse extra config JSON if provided
155
+ extra_config = {}
156
+ if storage_backend_extra_config:
157
+ try:
158
+ extra_config = json.loads(storage_backend_extra_config)
159
+ except Exception as e:
160
+ logger.error(f"Invalid backend extra config JSON: {e}")
161
+ raise e
162
+
163
+ prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
164
+ prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
165
+ prefetch_timeout_per_ki_token = extra_config.pop(
166
+ "prefetch_timeout_per_ki_token", 0.25
167
+ ) # seconds per 1024 tokens
168
+
169
+ if not isinstance(prefetch_threshold, int):
170
+ raise ValueError(
171
+ f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
172
+ )
173
+ if not isinstance(prefetch_timeout_base, (int, float)):
174
+ raise ValueError(
175
+ f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
176
+ )
177
+ if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
178
+ raise ValueError(
179
+ f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
180
+ )
181
+
182
+ return (
183
+ extra_config,
184
+ prefetch_threshold,
185
+ float(prefetch_timeout_base),
186
+ float(prefetch_timeout_per_ki_token),
122
187
  )
123
188
 
124
189
  def reset(self):
@@ -258,12 +323,15 @@ class HiRadixCache(RadixCache):
258
323
 
259
324
  def evict(self, num_tokens: int):
260
325
  leaves = self._collect_leaves_device()
261
- heapq.heapify(leaves)
326
+ eviction_heap = [
327
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
328
+ ]
329
+ heapq.heapify(eviction_heap)
262
330
 
263
331
  num_evicted = 0
264
332
  write_back_nodes = []
265
- while num_evicted < num_tokens and len(leaves):
266
- x = heapq.heappop(leaves)
333
+ while num_evicted < num_tokens and len(eviction_heap):
334
+ _priority, x = heapq.heappop(eviction_heap)
267
335
 
268
336
  if x.lock_ref > 0:
269
337
  continue
@@ -285,7 +353,8 @@ class HiRadixCache(RadixCache):
285
353
  break
286
354
  else:
287
355
  # all children are evicted or no children
288
- heapq.heappush(leaves, x.parent)
356
+ new_priority = self.eviction_strategy.get_priority(x.parent)
357
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
289
358
 
290
359
  if self.cache_controller.write_policy == "write_back":
291
360
  self.writing_check(write_back=True)
@@ -295,7 +364,7 @@ class HiRadixCache(RadixCache):
295
364
 
296
365
  def _evict_backuped(self, node: TreeNode):
297
366
  # evict a node already written to host
298
- num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
367
+ num_evicted = self.cache_controller.evict_device(node.value)
299
368
  assert num_evicted > 0
300
369
  self.evictable_size_ -= num_evicted
301
370
  node.value = None
@@ -310,11 +379,14 @@ class HiRadixCache(RadixCache):
310
379
 
311
380
  def evict_host(self, num_tokens: int):
312
381
  leaves = self._collect_leaves()
313
- heapq.heapify(leaves)
382
+ eviction_heap = [
383
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
384
+ ]
385
+ heapq.heapify(eviction_heap)
314
386
 
315
387
  num_evicted = 0
316
- while num_evicted < num_tokens and len(leaves):
317
- x = heapq.heappop(leaves)
388
+ while num_evicted < num_tokens and len(eviction_heap):
389
+ _priority, x = heapq.heappop(eviction_heap)
318
390
  if x == self.root_node:
319
391
  break
320
392
  # only evict the host value of evicted nodes
@@ -333,7 +405,8 @@ class HiRadixCache(RadixCache):
333
405
  del x.parent.children[k]
334
406
 
335
407
  if len(x.parent.children) == 0 and x.parent.evicted:
336
- heapq.heappush(leaves, x.parent)
408
+ new_priority = self.eviction_strategy.get_priority(x.parent)
409
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
337
410
 
338
411
  def load_back(
339
412
  self, node: TreeNode, mem_quota: Optional[int] = None
@@ -476,6 +549,15 @@ class HiRadixCache(RadixCache):
476
549
  host_indices = torch.cat(host_indices_list, dim=0)
477
550
  cc.mem_pool_host.free(host_indices)
478
551
 
552
+ # Timeout is linearly increasing with the number of pages
553
+ def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
554
+ # If hash_value has not been computed in timeout_base seconds, terminate it.
555
+ return (
556
+ time.monotonic() - operation.start_time
557
+ > self.prefetch_timeout_base
558
+ + len(operation.hash_value) * self.prefetch_timeout_per_page
559
+ )
560
+
479
561
  def can_terminate_prefetch(self, operation: PrefetchOperation):
480
562
  can_terminate = True
481
563
 
@@ -492,9 +574,7 @@ class HiRadixCache(RadixCache):
492
574
  if self.prefetch_stop_policy == "wait_complete":
493
575
  can_terminate = completed
494
576
  elif self.prefetch_stop_policy == "timeout":
495
- can_terminate = completed or (
496
- time.monotonic() - operation.start_time > self.prefetch_timeout
497
- )
577
+ can_terminate = completed or self.is_prefetch_timeout(operation)
498
578
  else:
499
579
  # unknown prefetch stop policy, just return True
500
580
  return True
@@ -556,12 +636,12 @@ class HiRadixCache(RadixCache):
556
636
  written_indices = host_indices[:min_completed_tokens]
557
637
  matched_length = self._insert_helper_host(
558
638
  last_host_node,
559
- fetched_token_ids,
639
+ RadixKey(
640
+ token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
641
+ ),
560
642
  written_indices,
561
643
  hash_value[: min_completed_tokens // self.page_size],
562
644
  )
563
- if len(written_indices):
564
- self.cache_controller.mem_pool_host.update_prefetch(written_indices)
565
645
 
566
646
  self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
567
647
  self.cache_controller.append_host_mem_release(
@@ -578,8 +658,9 @@ class HiRadixCache(RadixCache):
578
658
 
579
659
  return True
580
660
 
581
- def match_prefix(self, key: List[int], **kwargs):
661
+ def match_prefix(self, key: RadixKey, **kwargs):
582
662
  empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
663
+ key.token_ids = self.key_convert_fn(key.token_ids)
583
664
  if self.disable or len(key) == 0:
584
665
  return MatchResult(
585
666
  device_indices=empty_value,
@@ -652,7 +733,9 @@ class HiRadixCache(RadixCache):
652
733
  )
653
734
  self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
654
735
 
655
- def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
736
+ def _insert_helper_host(
737
+ self, node: TreeNode, key: RadixKey, host_value, hash_value
738
+ ):
656
739
  node.last_access_time = time.monotonic()
657
740
  if len(key) == 0:
658
741
  return 0
@@ -686,7 +769,7 @@ class HiRadixCache(RadixCache):
686
769
  node.children[child_key] = new_node
687
770
  return matched_length
688
771
 
689
- def _match_prefix_helper(self, node: TreeNode, key: List):
772
+ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
690
773
  node.last_access_time = time.monotonic()
691
774
  child_key = self.get_child_key_fn(key)
692
775
  value = []
@@ -712,7 +795,7 @@ class HiRadixCache(RadixCache):
712
795
 
713
796
  return value, node
714
797
 
715
- def _split_node(self, key, child: TreeNode, split_len: int):
798
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
716
799
  # child node split into new_node -> child
717
800
  new_node = TreeNode()
718
801
  new_node.children = {self.get_child_key_fn(key[split_len:]): child}
@@ -739,10 +822,16 @@ class HiRadixCache(RadixCache):
739
822
  new_node.parent.children[self.get_child_key_fn(key)] = new_node
740
823
  return new_node
741
824
 
742
- def insert(self, key: List, value, chunked=False):
825
+ def insert(self, key: RadixKey, value=None, chunked=False):
826
+ key.token_ids = self.key_convert_fn(key.token_ids)
827
+
743
828
  if len(key) == 0:
744
829
  return 0
745
830
 
831
+ if self.is_eagle and value is not None:
832
+ # Make sure the value len equal to the EAGLE bigram key len
833
+ value = value[: len(key)]
834
+
746
835
  node = self.root_node
747
836
  child_key = self.get_child_key_fn(key)
748
837
  total_prefix_length = 0
@@ -757,7 +846,6 @@ class HiRadixCache(RadixCache):
757
846
  # change the reference if the node is evicted
758
847
  # this often happens in the case of KV cache recomputation
759
848
  node.value = value[:prefix_len]
760
- self.token_to_kv_pool_host.update_synced(node.host_value)
761
849
  self.evictable_size_ += len(node.value)
762
850
  else:
763
851
  self._inc_hit_count(node, chunked)
@@ -767,7 +855,6 @@ class HiRadixCache(RadixCache):
767
855
  new_node = self._split_node(node.key, node, prefix_len)
768
856
  if new_node.evicted:
769
857
  new_node.value = value[:prefix_len]
770
- self.token_to_kv_pool_host.update_synced(new_node.host_value)
771
858
  self.evictable_size_ += len(new_node.value)
772
859
  else:
773
860
  self._inc_hit_count(new_node, chunked)
@@ -797,7 +884,7 @@ class HiRadixCache(RadixCache):
797
884
  for idx in range(0, len(key), self.page_size):
798
885
  new_node.hash_value.append(
799
886
  self.cache_controller.get_hash_str(
800
- key[idx : idx + self.page_size],
887
+ key.token_ids[idx : idx + self.page_size],
801
888
  prior_hash=last_hash,
802
889
  )
803
890
  )
@@ -15,6 +15,8 @@ limitations under the License.
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ from sglang.srt.layers.attention.nsa import index_buf_accessor
19
+ from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
18
20
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
19
21
 
20
22
  """
@@ -1030,6 +1032,8 @@ class MLATokenToKVPool(KVCache):
1030
1032
  enable_memory_saver: bool,
1031
1033
  start_layer: Optional[int] = None,
1032
1034
  end_layer: Optional[int] = None,
1035
+ use_nsa: bool = False,
1036
+ override_kv_cache_dim: Optional[int] = None,
1033
1037
  ):
1034
1038
  super().__init__(
1035
1039
  size,
@@ -1044,6 +1048,14 @@ class MLATokenToKVPool(KVCache):
1044
1048
 
1045
1049
  self.kv_lora_rank = kv_lora_rank
1046
1050
  self.qk_rope_head_dim = qk_rope_head_dim
1051
+ self.use_nsa = use_nsa
1052
+ self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
1053
+ # TODO do not hardcode
1054
+ self.kv_cache_dim = (
1055
+ 656
1056
+ if self.use_nsa and self.nsa_kv_cache_store_fp8
1057
+ else (kv_lora_rank + qk_rope_head_dim)
1058
+ )
1047
1059
 
1048
1060
  # for disagg with nvlink
1049
1061
  self.enable_custom_mem_pool = get_bool_env_var(
@@ -1067,7 +1079,7 @@ class MLATokenToKVPool(KVCache):
1067
1079
  # The padded slot 0 is used for writing dummy outputs from padded tokens.
1068
1080
  self.kv_buffer = [
1069
1081
  torch.zeros(
1070
- (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
1082
+ (size + page_size, 1, self.kv_cache_dim),
1071
1083
  dtype=self.store_dtype,
1072
1084
  device=device,
1073
1085
  )
@@ -1130,6 +1142,7 @@ class MLATokenToKVPool(KVCache):
1130
1142
  cache_v: torch.Tensor,
1131
1143
  ):
1132
1144
  layer_id = layer.layer_id
1145
+ assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
1133
1146
  if cache_k.dtype != self.dtype:
1134
1147
  cache_k = cache_k.to(self.dtype)
1135
1148
  if self.store_dtype != self.dtype:
@@ -1147,16 +1160,28 @@ class MLATokenToKVPool(KVCache):
1147
1160
  cache_k_rope: torch.Tensor,
1148
1161
  ):
1149
1162
  layer_id = layer.layer_id
1150
- if cache_k_nope.dtype != self.dtype:
1151
- cache_k_nope = cache_k_nope.to(self.dtype)
1152
- cache_k_rope = cache_k_rope.to(self.dtype)
1153
- if self.store_dtype != self.dtype:
1154
- cache_k_nope = cache_k_nope.view(self.store_dtype)
1155
- cache_k_rope = cache_k_rope.view(self.store_dtype)
1156
1163
 
1157
- set_mla_kv_buffer_triton(
1158
- self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
1159
- )
1164
+ if self.use_nsa and self.nsa_kv_cache_store_fp8:
1165
+ # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
1166
+ # TODO no need to cat
1167
+ cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
1168
+ cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
1169
+ cache_k = cache_k.view(self.store_dtype)
1170
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
1171
+ else:
1172
+ if cache_k_nope.dtype != self.dtype:
1173
+ cache_k_nope = cache_k_nope.to(self.dtype)
1174
+ cache_k_rope = cache_k_rope.to(self.dtype)
1175
+ if self.store_dtype != self.dtype:
1176
+ cache_k_nope = cache_k_nope.view(self.store_dtype)
1177
+ cache_k_rope = cache_k_rope.view(self.store_dtype)
1178
+
1179
+ set_mla_kv_buffer_triton(
1180
+ self.kv_buffer[layer_id - self.start_layer],
1181
+ loc,
1182
+ cache_k_nope,
1183
+ cache_k_rope,
1184
+ )
1160
1185
 
1161
1186
  def get_cpu_copy(self, indices):
1162
1187
  torch.cuda.synchronize()
@@ -1186,6 +1211,103 @@ class MLATokenToKVPool(KVCache):
1186
1211
  torch.cuda.synchronize()
1187
1212
 
1188
1213
 
1214
+ class NSATokenToKVPool(MLATokenToKVPool):
1215
+ def __init__(
1216
+ self,
1217
+ size: int,
1218
+ page_size: int,
1219
+ kv_lora_rank: int,
1220
+ dtype: torch.dtype,
1221
+ qk_rope_head_dim: int,
1222
+ layer_num: int,
1223
+ device: str,
1224
+ index_head_dim: int,
1225
+ enable_memory_saver: bool,
1226
+ start_layer: Optional[int] = None,
1227
+ end_layer: Optional[int] = None,
1228
+ ):
1229
+ super().__init__(
1230
+ size,
1231
+ page_size,
1232
+ dtype,
1233
+ kv_lora_rank,
1234
+ qk_rope_head_dim,
1235
+ layer_num,
1236
+ device,
1237
+ enable_memory_saver,
1238
+ start_layer,
1239
+ end_layer,
1240
+ use_nsa=True,
1241
+ )
1242
+ # self.index_k_dtype = torch.float8_e4m3fn
1243
+ # self.index_k_scale_dtype = torch.float32
1244
+ self.index_head_dim = index_head_dim
1245
+ # num head == 1 and head dim == 128 for index_k in NSA
1246
+ assert index_head_dim == 128
1247
+
1248
+ self.quant_block_size = 128
1249
+
1250
+ assert self.page_size == 64
1251
+ self.index_k_with_scale_buffer = [
1252
+ torch.zeros(
1253
+ # Layout:
1254
+ # ref: test_attention.py :: kv_cache_cast_to_fp8
1255
+ # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
1256
+ # data: for page i,
1257
+ # * buf[i, :page_size * head_dim] for fp8 data
1258
+ # * buf[i, page_size * head_dim:].view(float32) for scale
1259
+ (
1260
+ (size + page_size + 1) // self.page_size,
1261
+ self.page_size
1262
+ * (index_head_dim + index_head_dim // self.quant_block_size * 4),
1263
+ ),
1264
+ dtype=torch.uint8,
1265
+ device=device,
1266
+ )
1267
+ for _ in range(layer_num)
1268
+ ]
1269
+
1270
+ def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
1271
+ if self.layer_transfer_counter is not None:
1272
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
1273
+ return self.index_k_with_scale_buffer[layer_id - self.start_layer]
1274
+
1275
+ def get_index_k_continuous(
1276
+ self,
1277
+ layer_id: int,
1278
+ seq_len: int,
1279
+ page_indices: torch.Tensor,
1280
+ ):
1281
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
1282
+ return index_buf_accessor.GetK.execute(
1283
+ self, buf, seq_len=seq_len, page_indices=page_indices
1284
+ )
1285
+
1286
+ def get_index_k_scale_continuous(
1287
+ self,
1288
+ layer_id: int,
1289
+ seq_len: int,
1290
+ page_indices: torch.Tensor,
1291
+ ):
1292
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
1293
+ return index_buf_accessor.GetS.execute(
1294
+ self, buf, seq_len=seq_len, page_indices=page_indices
1295
+ )
1296
+
1297
+ # TODO rename later (currently use diff name to avoid confusion)
1298
+ def set_index_k_and_scale_buffer(
1299
+ self,
1300
+ layer_id: int,
1301
+ loc: torch.Tensor,
1302
+ index_k: torch.Tensor,
1303
+ index_k_scale: torch.Tensor,
1304
+ ) -> None:
1305
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
1306
+ index_buf_accessor.SetKAndS.execute(
1307
+ pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
1308
+ )
1309
+
1310
+
1189
1311
  class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1190
1312
  def __init__(
1191
1313
  self,
@@ -1194,6 +1316,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1194
1316
  dtype: torch.dtype,
1195
1317
  kv_lora_rank: int,
1196
1318
  qk_rope_head_dim: int,
1319
+ index_head_dim: Optional[int],
1197
1320
  layer_num: int,
1198
1321
  device: str,
1199
1322
  enable_memory_saver: bool,
@@ -1213,6 +1336,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1213
1336
 
1214
1337
  self.kv_lora_rank = kv_lora_rank
1215
1338
  self.qk_rope_head_dim = qk_rope_head_dim
1339
+ self.index_head_dim = index_head_dim
1216
1340
 
1217
1341
  self.custom_mem_pool = None
1218
1342
 
@@ -1240,6 +1364,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1240
1364
  dtype=self.store_dtype,
1241
1365
  device=self.device,
1242
1366
  )
1367
+ if self.index_head_dim is not None:
1368
+ self.index_k_buffer = torch.zeros(
1369
+ (
1370
+ layer_num,
1371
+ self.size // self.page_size + 1,
1372
+ self.page_size,
1373
+ 1,
1374
+ self.index_head_dim,
1375
+ ),
1376
+ dtype=self.store_dtype,
1377
+ device=self.device,
1378
+ )
1243
1379
 
1244
1380
  self._finalize_allocation_log(size)
1245
1381
 
@@ -1251,6 +1387,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1251
1387
  kv_size_bytes += get_tensor_size_bytes(k_cache)
1252
1388
  for v_cache in self.v_buffer:
1253
1389
  kv_size_bytes += get_tensor_size_bytes(v_cache)
1390
+ if self.index_head_dim is not None:
1391
+ assert hasattr(self, "index_k_buffer")
1392
+ for index_k_cache in self.index_k_buffer:
1393
+ kv_size_bytes += get_tensor_size_bytes(index_k_cache)
1254
1394
  return kv_size_bytes
1255
1395
 
1256
1396
  def get_kv_buffer(self, layer_id: int):
@@ -1277,6 +1417,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1277
1417
  return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
1278
1418
  return self.v_buffer[layer_id - self.start_layer]
1279
1419
 
1420
+ def get_index_k_buffer(self, layer_id: int):
1421
+ if self.layer_transfer_counter is not None:
1422
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
1423
+
1424
+ if self.store_dtype != self.dtype:
1425
+ return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
1426
+ return self.index_k_buffer[layer_id - self.start_layer]
1427
+
1280
1428
  # for disagg
1281
1429
  def get_contiguous_buf_infos(self):
1282
1430
  # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -1289,6 +1437,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1289
1437
  kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
1290
1438
  self.v_buffer[i][0].nbytes for i in range(self.layer_num)
1291
1439
  ]
1440
+ if self.index_head_dim is not None:
1441
+ kv_data_ptrs += [
1442
+ self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
1443
+ ]
1444
+ kv_data_lens += [
1445
+ self.index_k_buffer[i].nbytes for i in range(self.layer_num)
1446
+ ]
1447
+ kv_item_lens += [
1448
+ self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
1449
+ ]
1292
1450
  return kv_data_ptrs, kv_data_lens, kv_item_lens
1293
1451
 
1294
1452
  def set_kv_buffer(
@@ -1325,6 +1483,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1325
1483
  cache_v.view(-1, 1, self.qk_rope_head_dim),
1326
1484
  )
1327
1485
 
1486
+ def set_index_k_buffer(
1487
+ self,
1488
+ layer_id: int,
1489
+ loc: torch.Tensor,
1490
+ index_k: torch.Tensor,
1491
+ ):
1492
+ if index_k.dtype != self.dtype:
1493
+ index_k = index_k.to(self.dtype)
1494
+
1495
+ if self.store_dtype != self.dtype:
1496
+ index_k = index_k.view(self.store_dtype)
1497
+
1498
+ torch_npu.npu_scatter_nd_update_(
1499
+ self.index_k_buffer[layer_id - self.start_layer].view(
1500
+ -1, 1, self.index_head_dim
1501
+ ),
1502
+ loc.view(-1, 1),
1503
+ index_k.view(-1, 1, self.index_head_dim),
1504
+ )
1505
+
1328
1506
 
1329
1507
  class DoubleSparseTokenToKVPool(KVCache):
1330
1508
  def __init__(