sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
13
13
  limitations under the License.
14
14
  """
15
15
 
16
+ from __future__ import annotations
17
+
16
18
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
17
19
 
18
20
  """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
27
29
  import abc
28
30
  import logging
29
31
  from contextlib import nullcontext
30
- from typing import Dict, List, Optional, Tuple, Union
32
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
31
33
 
32
34
  import numpy as np
33
35
  import torch
@@ -36,12 +38,18 @@ import triton.language as tl
36
38
 
37
39
  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
38
40
  from sglang.srt.layers.radix_attention import RadixAttention
39
- from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
41
+ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
42
+
43
+ if TYPE_CHECKING:
44
+ from sglang.srt.managers.cache_controller import LayerDoneCounter
40
45
 
41
46
  logger = logging.getLogger(__name__)
42
47
 
43
48
  GB = 1024 * 1024 * 1024
44
49
  _is_cuda = is_cuda()
50
+ _is_npu = is_npu()
51
+ if _is_npu:
52
+ import torch_npu
45
53
 
46
54
 
47
55
  class ReqToTokenPool:
@@ -94,6 +102,207 @@ class ReqToTokenPool:
94
102
  self.free_slots = list(range(self.size))
95
103
 
96
104
 
105
class MambaPool:
    """Memory pool of per-slot Mamba layer states.

    Two buffers are kept per layer: a causal-conv window (``conv_state``) and
    an SSM temporal state (``temporal_state``). When speculative decoding is
    enabled (``speculative_num_draft_tokens`` is not None), two extra buffers
    cache intermediate states per draft token for target verification.
    """

    def __init__(
        self,
        size: int,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        num_mamba_layers: int,
        conv_state_shape: Tuple[int, int],
        # NOTE: indexed with [0], [1], [2] below, so this is a 3-tuple.
        temporal_state_shape: Tuple[int, int, int],
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        # All buffers carry one extra slot (size + 1); presumably a padded
        # dummy slot mirroring the KV pools -- confirm with callers.
        conv_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + conv_state_shape,
            dtype=conv_dtype,
            device=device,
        )
        temporal_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + temporal_state_shape,
            dtype=ssm_dtype,
            device=device,
        )
        if speculative_num_draft_tokens is not None:
            # Cache intermediate SSM states per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
            intermediate_ssm_state_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    temporal_state_shape[0],
                    temporal_state_shape[1],
                    temporal_state_shape[2],
                ),
                dtype=ssm_dtype,
                # BUG FIX: was hardcoded to "cuda"; use the requested device,
                # consistent with conv_state/temporal_state above.
                device=device,
            )
            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
            intermediate_conv_window_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    conv_state_shape[0],
                    conv_state_shape[1],
                ),
                dtype=conv_dtype,
                # BUG FIX: was hardcoded to "cuda" (see above).
                device=device,
            )
            self.mamba_cache = (
                conv_state,
                temporal_state,
                intermediate_ssm_state_cache,
                intermediate_conv_window_cache,
            )
        else:
            self.mamba_cache = (conv_state, temporal_state)
        self.size = size
        self.free_slots = list(range(size))
        self.mem_usage = self.get_mamba_size() / GB
        logger.info(
            f"Mamba Cache is allocated. "
            f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
            f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB "
        )

    def get_mamba_params_all_layers(self):
        """Return every cached buffer, each spanning all layers."""
        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]

    def get_mamba_params(self, layer_id: int):
        """Return every cached buffer sliced to a single (pool-local) layer."""
        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]

    def get_mamba_size(self):
        """Size in bytes of the conv + temporal state buffers only
        (intermediate speculative caches are not counted)."""
        return (
            np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
            + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
        )

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate ``need_size`` slots; return their indices, or None if the
        pool does not have enough free slots."""
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return slot(s) to the pool and zero their conv/temporal states."""
        # FIX: single-element tuple in isinstance replaced with the plain type.
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)
        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0

    def clear(self):
        self.free_slots = list(range(self.size))
205
+
206
+
207
class HybridReqToTokenPool(ReqToTokenPool):
    """A memory pool that maps a request to its token locations, plus a
    per-request Mamba state slot for hybrid (attention + Mamba) models."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        mamba_layers: List[int],
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        speculative_num_draft_tokens: int,
    ):
        super().__init__(
            size=size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=enable_memory_saver,
        )

        self.mamba_pool = MambaPool(
            size,
            conv_dtype,
            ssm_dtype,
            len(mamba_layers),
            conv_state_shape,
            temporal_state_shape,
            device,
            speculative_num_draft_tokens,
        )
        # Global (model) layer id -> pool-local mamba layer index.
        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}

        self.device = device
        # Row i holds the mamba slot assigned to request-pool index i.
        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
            size, dtype=torch.int32, device=self.device
        )

        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
        self.mamba_index_to_rid_mapping: Dict[int, str] = {}

    # For chunk prefill req, we do not need to allocate mamba cache,
    # We could use allocated mamba cache instead.
    def alloc(
        self, need_size: int, reqs: Optional[List["Req"]] = None
    ) -> Optional[List[int]]:
        select_index = super().alloc(need_size)
        # FIX: identity comparison for None (was `== None`).
        if select_index is None:
            return None

        mamba_index = []
        for req in reqs:
            rid = req.rid
            if rid in self.rid_to_mamba_index_mapping:
                # Chunked-prefill continuation: reuse the request's slot.
                mamba_index.append(self.rid_to_mamba_index_mapping[rid])
            elif (mid := self.mamba_pool.alloc(1)) is not None:
                mid = mid[0]
                self.rid_to_mamba_index_mapping[rid] = mid
                self.mamba_index_to_rid_mapping[mid] = rid
                mamba_index.append(mid)
            # FIX: previously a failed mamba allocation still appended None,
            # so the assert below could never fire and the error surfaced as
            # an unrelated TypeError from torch.tensor. Now the length
            # mismatch triggers the intended message.
        assert len(select_index) == len(
            mamba_index
        ), "Not enough space for mamba cache, try to increase --max-mamba-cache-size."
        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
            mamba_index, dtype=torch.int32, device=self.device
        )
        return select_index

    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
        return self.req_index_to_mamba_index_mapping[req_indices]

    def get_mamba_params(self, layer_id: int):
        assert layer_id in self.mamba_map
        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])

    def get_mamba_params_all_layers(self):
        return self.mamba_pool.get_mamba_params_all_layers()

    # For chunk prefill, we can not free mamba cache, we need use it in the future
    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
        super().free(free_index)
        if free_mamba_cache:
            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
            mamba_index_list = mamba_index.tolist()
            # tolist() yields a scalar for a 0-d tensor; normalize to a list.
            if isinstance(mamba_index_list, int):
                mamba_index_list = [mamba_index_list]
            self.mamba_pool.free(mamba_index_list)
            for mid in mamba_index_list:
                rid = self.mamba_index_to_rid_mapping[mid]
                self.mamba_index_to_rid_mapping.pop(mid)
                self.rid_to_mamba_index_mapping.pop(rid)

    def clear(self):
        super().clear()
        self.mamba_pool.clear()
304
+
305
+
97
306
  class KVCache(abc.ABC):
98
307
  @abc.abstractmethod
99
308
  def __init__(
@@ -127,6 +336,29 @@ class KVCache(abc.ABC):
127
336
  # used for chunked cpu-offloading
128
337
  self.cpu_offloading_chunk_size = 8192
129
338
 
339
+ # default state for optional layer-wise transfer control
340
+ self.layer_transfer_counter = None
341
+
342
+ def _finalize_allocation_log(self, num_tokens: int):
343
+ """Common logging and mem_usage computation for KV cache allocation.
344
+ Supports both tuple (K, V) size returns and single KV size returns.
345
+ """
346
+ kv_size_bytes = self.get_kv_size_bytes()
347
+ if isinstance(kv_size_bytes, tuple):
348
+ k_size, v_size = kv_size_bytes
349
+ k_size_GB = k_size / GB
350
+ v_size_GB = v_size / GB
351
+ logger.info(
352
+ f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
353
+ )
354
+ self.mem_usage = k_size_GB + v_size_GB
355
+ else:
356
+ kv_size_GB = kv_size_bytes / GB
357
+ logger.info(
358
+ f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
359
+ )
360
+ self.mem_usage = kv_size_GB
361
+
130
362
  @abc.abstractmethod
131
363
  def get_key_buffer(self, layer_id: int) -> torch.Tensor:
132
364
  raise NotImplementedError()
@@ -149,7 +381,7 @@ class KVCache(abc.ABC):
149
381
  ) -> None:
150
382
  raise NotImplementedError()
151
383
 
152
    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
        # Registers the counter that buffer getters wait on (via wait_until)
        # before returning a layer's KV data; presumably signalled when a
        # layer's transfer completes -- confirm in cache_controller.
        self.layer_transfer_counter = layer_transfer_counter
154
386
 
155
387
  def get_cpu_copy(self, indices):
@@ -202,15 +434,9 @@ class MHATokenToKVPool(KVCache):
202
434
 
203
435
  self._create_buffers()
204
436
 
205
- self.layer_transfer_counter = None
206
437
  self.device_module = torch.get_device_module(self.device)
207
438
  self.alt_stream = self.device_module.Stream() if _is_cuda else None
208
-
209
- k_size, v_size = self.get_kv_size_bytes()
210
- logger.info(
211
- f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
212
- )
213
- self.mem_usage = (k_size + v_size) / GB
439
+ self._finalize_allocation_log(size)
214
440
 
215
441
  def _create_buffers(self):
216
442
  with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -349,7 +575,6 @@ class MHATokenToKVPool(KVCache):
349
575
  # same applies to get_value_buffer and get_kv_buffer
350
576
  if self.layer_transfer_counter is not None:
351
577
  self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
352
-
353
578
  return self._get_key_buffer(layer_id)
354
579
 
355
580
  def _get_value_buffer(self, layer_id: int):
@@ -417,50 +642,119 @@ class MHATokenToKVPool(KVCache):
417
642
  )
418
643
 
419
644
 
420
- class SWAKVPool(KVCache):
421
- """KV cache with separate pools for full and SWA attention layers."""
645
class HybridLinearKVPool(KVCache):
    """KV cache with separate pools for full and linear attention layers."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        device: str,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
        self.page_size = 1
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose
        # Only the full-attention layers are backed by a token-to-KV pool.
        self.full_kv_pool = MHATokenToKVPool(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            layer_num=self.full_layer_nums,
            device=device,
            enable_memory_saver=False,
        )
        # Global layer id -> index inside the full-attention pool.
        self.full_attention_layer_id_mapping = {
            layer: idx for idx, layer in enumerate(full_attention_layer_ids)
        }
        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        """Delegate to the underlying full-attention pool."""
        return self.full_kv_pool.get_kv_size_bytes()

    def get_contiguous_buf_infos(self):
        """Delegate to the underlying full-attention pool."""
        return self.full_kv_pool.get_contiguous_buf_infos()

    def _transfer_full_attention_id(self, layer_id: int):
        """Map a global layer id to the pool-local index, or raise ValueError
        for a layer that has no KV pool (i.e. a linear-attention layer)."""
        mapped = self.full_attention_layer_id_mapping.get(layer_id)
        if mapped is None:
            raise ValueError(
                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
            )
        return mapped

    def get_key_buffer(self, layer_id: int):
        return self.full_kv_pool.get_key_buffer(
            self._transfer_full_attention_id(layer_id)
        )

    def get_value_buffer(self, layer_id: int):
        return self.full_kv_pool.get_value_buffer(
            self._transfer_full_attention_id(layer_id)
        )

    def get_kv_buffer(self, layer_id: int):
        return self.full_kv_pool.get_kv_buffer(
            self._transfer_full_attention_id(layer_id)
        )

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        # Route the write through the inner pool, overriding the layer id
        # with the pool-local index; the layer object itself is not needed.
        pool_layer_id = self._transfer_full_attention_id(layer.layer_id)
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=pool_layer_id,
        )
+ )
725
+
726
+
727
+ class SWAKVPool(KVCache):
728
+ """KV cache with separate pools for full and SWA attention layers."""
729
+
730
+ def __init__(
731
+ self,
732
+ size: int,
733
+ size_swa: int,
734
+ swa_attention_layer_ids: List[int],
735
+ full_attention_layer_ids: List[int],
736
+ enable_kvcache_transpose: bool,
737
+ token_to_kv_pool_class: KVCache = MHATokenToKVPool,
738
+ **kwargs,
739
+ ):
740
+ self.size = size
741
+ self.size_swa = size_swa
742
+ self.swa_layer_nums = len(swa_attention_layer_ids)
743
+ self.full_layer_nums = len(full_attention_layer_ids)
744
+ kwargs["page_size"] = 1
745
+ kwargs["enable_memory_saver"] = False
746
+ # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
747
+ assert not enable_kvcache_transpose
748
+
749
+ self.swa_kv_pool = token_to_kv_pool_class(
750
+ size=size_swa,
751
+ layer_num=self.swa_layer_nums,
752
+ **kwargs,
753
+ )
754
+ self.full_kv_pool = token_to_kv_pool_class(
456
755
  size=size,
457
- page_size=self.page_size,
458
- dtype=dtype,
459
- head_num=head_num,
460
- head_dim=head_dim,
461
756
  layer_num=self.full_layer_nums,
462
- device=device,
463
- enable_memory_saver=False,
757
+ **kwargs,
464
758
  )
465
759
  self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
466
760
  for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
@@ -624,8 +918,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
624
918
  cache_k = cache_k.view(self.store_dtype)
625
919
  cache_v = cache_v.view(self.store_dtype)
626
920
 
627
- import torch_npu
628
-
629
921
  torch_npu._npu_reshape_and_cache(
630
922
  key=cache_k,
631
923
  value=cache_v,
@@ -767,13 +1059,7 @@ class MLATokenToKVPool(KVCache):
767
1059
  dtype=torch.uint64,
768
1060
  device=self.device,
769
1061
  )
770
- self.layer_transfer_counter = None
771
-
772
- kv_size = self.get_kv_size_bytes()
773
- logger.info(
774
- f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
775
- )
776
- self.mem_usage = kv_size / GB
1062
+ self._finalize_allocation_log(size)
777
1063
 
778
1064
  def get_kv_size_bytes(self):
779
1065
  assert hasattr(self, "kv_buffer")
@@ -912,31 +1198,77 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
912
1198
 
913
1199
  with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
914
1200
  # The padded slot 0 is used for writing dummy outputs from padded tokens.
915
- self.kv_buffer = torch.zeros(
1201
+ self.k_buffer = torch.zeros(
1202
+ (
1203
+ layer_num,
1204
+ self.size // self.page_size + 1,
1205
+ self.page_size,
1206
+ 1,
1207
+ self.kv_lora_rank,
1208
+ ),
1209
+ dtype=self.store_dtype,
1210
+ device=self.device,
1211
+ )
1212
+ self.v_buffer = torch.zeros(
916
1213
  (
917
1214
  layer_num,
918
1215
  self.size // self.page_size + 1,
919
1216
  self.page_size,
920
- self.kv_lora_rank + self.qk_rope_head_dim,
1217
+ 1,
1218
+ self.qk_rope_head_dim,
921
1219
  ),
922
1220
  dtype=self.store_dtype,
923
1221
  device=self.device,
924
1222
  )
925
1223
 
926
- self.layer_transfer_counter = None
1224
+ self._finalize_allocation_log(size)
927
1225
 
928
- kv_size = self.get_kv_size_bytes()
929
- logger.info(
930
- f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
1226
+ def get_kv_size_bytes(self):
1227
+ assert hasattr(self, "k_buffer")
1228
+ assert hasattr(self, "v_buffer")
1229
+ kv_size_bytes = 0
1230
+ for k_cache in self.k_buffer:
1231
+ kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
1232
+ for v_cache in self.v_buffer:
1233
+ kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
1234
+ return kv_size_bytes
1235
+
1236
+ def get_kv_buffer(self, layer_id: int):
1237
+ if self.layer_transfer_counter is not None:
1238
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
1239
+ return (
1240
+ self.k_buffer[layer_id - self.start_layer],
1241
+ self.v_buffer[layer_id - self.start_layer],
931
1242
  )
932
- self.mem_usage = kv_size / GB
1243
+
1244
+ def get_key_buffer(self, layer_id: int):
1245
+ if self.layer_transfer_counter is not None:
1246
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
1247
+
1248
+ if self.store_dtype != self.dtype:
1249
+ return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
1250
+ return self.k_buffer[layer_id - self.start_layer]
1251
+
1252
+ def get_value_buffer(self, layer_id: int):
1253
+ if self.layer_transfer_counter is not None:
1254
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
1255
+
1256
+ if self.store_dtype != self.dtype:
1257
+ return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
1258
+ return self.v_buffer[layer_id - self.start_layer]
933
1259
 
934
1260
  # for disagg
935
1261
  def get_contiguous_buf_infos(self):
936
1262
  # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
937
- kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
938
- kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
939
- kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
1263
+ kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
1264
+ self.v_buffer[i].data_ptr() for i in range(self.layer_num)
1265
+ ]
1266
+ kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
1267
+ self.v_buffer[i].nbytes for i in range(self.layer_num)
1268
+ ]
1269
+ kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
1270
+ self.v_buffer[i][0].nbytes for i in range(self.layer_num)
1271
+ ]
940
1272
  return kv_data_ptrs, kv_data_lens, kv_item_lens
941
1273
 
942
1274
  def set_kv_buffer(
@@ -949,18 +1281,28 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
949
1281
  layer_id = layer.layer_id
950
1282
  if cache_k.dtype != self.dtype:
951
1283
  cache_k = cache_k.to(self.dtype)
1284
+ cache_v = cache_v.to(self.dtype)
952
1285
 
953
1286
  if self.store_dtype != self.dtype:
954
1287
  cache_k = cache_k.view(self.store_dtype)
1288
+ cache_v = cache_v.view(self.store_dtype)
955
1289
 
956
- import torch_npu
1290
+ if cache_v is None:
1291
+ cache_k, cache_v = cache_k.split(
1292
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
1293
+ )
957
1294
 
958
- torch_npu._npu_reshape_and_cache_siso(
959
- key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
960
- key_cache=self.kv_buffer[layer_id - self.start_layer].view(
961
- -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
1295
+ torch_npu.npu_scatter_nd_update_(
1296
+ self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
1297
+ loc.view(-1, 1),
1298
+ cache_k.view(-1, 1, self.kv_lora_rank),
1299
+ )
1300
+ torch_npu.npu_scatter_nd_update_(
1301
+ self.v_buffer[layer_id - self.start_layer].view(
1302
+ -1, 1, self.qk_rope_head_dim
962
1303
  ),
963
- slot_indices=loc,
1304
+ loc.view(-1, 1),
1305
+ cache_v.view(-1, 1, self.qk_rope_head_dim),
964
1306
  )
965
1307
 
966
1308