sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
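
The TYPE_CHECKING guard above lets the new LayerDoneCounter annotation on register_layer_transfer_counter (see the @@ -152,7 hunk below) be checked statically without importing cache_controller at runtime, which would create an import cycle; the `from __future__ import annotations` added at the top of the module makes the annotation lazy. A minimal sketch of the pattern, reusing only the import path shown in this diff (PoolSketch is a hypothetical class, not part of sglang):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Followed only by type checkers; never executed, so no circular import.
    from sglang.srt.managers.cache_controller import LayerDoneCounter


class PoolSketch:
    def register_layer_transfer_counter(self, counter: LayerDoneCounter) -> None:
        self.layer_transfer_counter = counter
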
@@ -47,6 +52,10 @@ if _is_npu:
     import torch_npu
 
 
+def get_tensor_size_bytes(t: torch.Tensor):
+    return np.prod(t.shape) * t.dtype.itemsize
+
+
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
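
The new get_tensor_size_bytes helper centralizes the element-count-times-itemsize arithmetic that the size accounting below previously inlined at every call site. A quick sanity check of the math (assumes a recent torch where torch.dtype exposes itemsize):

import numpy as np
import torch

def get_tensor_size_bytes(t: torch.Tensor):
    # number of elements * bytes per element
    return np.prod(t.shape) * t.dtype.itemsize

x = torch.zeros(4, 256, 128, dtype=torch.float16)
# 4 * 256 * 128 elements * 2 bytes = 262144 bytes
assert get_tensor_size_bytes(x) == x.numel() * x.element_size() == 262144
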
@@ -97,6 +106,211 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class MambaPool:
+    def __init__(
+        self,
+        size: int,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        num_mamba_layers: int,
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        device: str,
+        speculative_num_draft_tokens: Optional[int] = None,
+    ):
+        conv_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + conv_state_shape,
+            dtype=conv_dtype,
+            device=device,
+        )
+        temporal_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + temporal_state_shape,
+            dtype=ssm_dtype,
+            device=device,
+        )
+        if speculative_num_draft_tokens is not None:
+            # Cache intermediate SSM states per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+            intermediate_ssm_state_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    temporal_state_shape[0],
+                    temporal_state_shape[1],
+                    temporal_state_shape[2],
+                ),
+                dtype=ssm_dtype,
+                device="cuda",
+            )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
+            self.mamba_cache = (
+                conv_state,
+                temporal_state,
+                intermediate_ssm_state_cache,
+                intermediate_conv_window_cache,
+            )
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+            )
+        else:
+            self.mamba_cache = (conv_state, temporal_state)
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+            )
+        self.size = size
+        self.free_slots = list(range(size))
+        self.mem_usage = self.get_mamba_size() / GB
+
+    def get_mamba_params_all_layers(self):
+        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_params(self, layer_id: int):
+        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_size(self):
+        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, (int,)):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+
+    def clear(self):
+        self.free_slots = list(range(self.size))
+
+
+class HybridReqToTokenPool(ReqToTokenPool):
+    """A memory pool that maps a request to its token locations and its Mamba cache slot."""
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        mamba_layers: List[int],
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        speculative_num_draft_tokens: int,
+    ):
+        super().__init__(
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+        )
+
+        self.mamba_pool = MambaPool(
+            size,
+            conv_dtype,
+            ssm_dtype,
+            len(mamba_layers),
+            conv_state_shape,
+            temporal_state_shape,
+            device,
+            speculative_num_draft_tokens,
+        )
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+
+        self.device = device
+        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
+            size, dtype=torch.int32, device=self.device
+        )
+
+        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+    # For a chunked-prefill request, we do not allocate a new mamba cache slot;
+    # we reuse the slot that was already allocated for the request.
+    def alloc(
+        self, need_size: int, reqs: Optional[List["Req"]] = None
+    ) -> Optional[List[int]]:
+        select_index = super().alloc(need_size)
+        if select_index is None:
+            return None
+
+        mamba_index = []
+        for req in reqs:
+            rid = req.rid
+            if rid in self.rid_to_mamba_index_mapping:
+                mid = self.rid_to_mamba_index_mapping[rid]
+            elif (mid := self.mamba_pool.alloc(1)) is not None:
+                mid = mid[0]
+                self.rid_to_mamba_index_mapping[rid] = mid
+                self.mamba_index_to_rid_mapping[mid] = rid
+            mamba_index.append(mid)
+        assert len(select_index) == len(
+            mamba_index
+        ), "Not enough space for mamba cache; try increasing --max-mamba-cache-size."
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
+        return select_index
+
+    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+        return self.req_index_to_mamba_index_mapping[req_indices]
+
+    def get_mamba_params(self, layer_id: int):
+        assert layer_id in self.mamba_map
+        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+
+    def get_mamba_params_all_layers(self):
+        return self.mamba_pool.get_mamba_params_all_layers()
+
+    # For chunked prefill we cannot free the mamba cache yet; it is still needed
+    # by the remaining chunks of the request.
+    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        super().free(free_index)
+        if free_mamba_cache:
+            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+            mamba_index_list = mamba_index.tolist()
+            if isinstance(mamba_index_list, int):
+                mamba_index_list = [mamba_index_list]
+            self.mamba_pool.free(mamba_index_list)
+            for mid in mamba_index_list:
+                rid = self.mamba_index_to_rid_mapping[mid]
+                self.mamba_index_to_rid_mapping.pop(mid)
+                self.rid_to_mamba_index_mapping.pop(rid)
+
+    def clear(self):
+        super().clear()
+        self.mamba_pool.clear()
+
+
 class KVCache(abc.ABC):
     @abc.abstractmethod
     def __init__(
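
MambaPool hands out per-request state slots from a simple free list: alloc returns slot indices and free zeroes the conv/SSM states before recycling them, while HybridReqToTokenPool keeps the rid-to-slot mappings so a chunked-prefill request keeps its slot across chunks. A minimal lifecycle sketch, assuming sglang 0.5.3rc0 is installed; the shapes are illustrative stand-ins, not values from any real model config:

import torch
from sglang.srt.mem_cache.memory_pool import MambaPool

pool = MambaPool(
    size=4,                         # four request slots
    conv_dtype=torch.float32,
    ssm_dtype=torch.float32,
    num_mamba_layers=2,
    conv_state_shape=(8, 3),        # illustrative (dim, kernel_size - 1)
    temporal_state_shape=(2, 4, 4),
    device="cpu",
)
slots = pool.alloc(2)               # -> [0, 1]
assert pool.available_size() == 2
pool.free(slots)                    # zeroes the freed states, then recycles the slots
assert pool.available_size() == 4
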
@@ -130,6 +344,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
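
_finalize_allocation_log dispatches on the return type of get_kv_size_bytes: MHA-style pools report a (k_size, v_size) tuple, MLA-style pools a single byte count, so the constructors below can share one code path (see the @@ -205,15, @@ -768,19, and @@ -936,22 hunks). A standalone sketch of that dispatch with made-up byte counts:

GB = 1024 * 1024 * 1024

def finalize_allocation_log(kv_size_bytes, num_tokens: int) -> float:
    # Mirrors KVCache._finalize_allocation_log; returns mem_usage in GB.
    if isinstance(kv_size_bytes, tuple):    # MHA pools: separate K and V buffers
        k_size, v_size = kv_size_bytes
        print(f"#tokens: {num_tokens}, K: {k_size / GB:.2f} GB, V: {v_size / GB:.2f} GB")
        return (k_size + v_size) / GB
    print(f"#tokens: {num_tokens}, KV: {kv_size_bytes / GB:.2f} GB")  # MLA pools
    return kv_size_bytes / GB

assert finalize_allocation_log((GB, GB), 1000) == 2.0
assert finalize_allocation_log(3 * GB, 1000) == 3.0
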
@@ -152,7 +389,7 @@
     ) -> None:
         raise NotImplementedError()
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter
 
     def get_cpu_copy(self, indices):
@@ -205,15 +442,9 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
-
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -269,10 +500,10 @@
         assert hasattr(self, "v_buffer")
         k_size_bytes = 0
         for k_cache in self.k_buffer:
-            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            k_size_bytes += get_tensor_size_bytes(k_cache)
         v_size_bytes = 0
         for v_cache in self.v_buffer:
-            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            v_size_bytes += get_tensor_size_bytes(v_cache)
         return k_size_bytes, v_size_bytes
 
     # for disagg
@@ -352,7 +583,6 @@
         # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
         return self._get_key_buffer(layer_id)
 
     def _get_value_buffer(self, layer_id: int):
@@ -420,41 +650,31 @@
         )
 
 
-class SWAKVPool(KVCache):
-    """KV cache with separate pools for full and SWA attention layers."""
+class HybridLinearKVPool(KVCache):
+    """KV cache with separate pools for full and linear attention layers."""
 
     def __init__(
         self,
         size: int,
-        size_swa: int,
         dtype: torch.dtype,
+        page_size: int,
         head_num: int,
         head_dim: int,
-        swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
     ):
         self.size = size
-        self.size_swa = size_swa
         self.dtype = dtype
         self.device = device
-        self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        self.page_size = page_size
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-        TokenToKVPoolClass = MHATokenToKVPool
-        self.swa_kv_pool = TokenToKVPoolClass(
-            size=size_swa,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
-            layer_num=self.swa_layer_nums,
-            device=device,
-            enable_memory_saver=False,
-        )
+        if _is_npu:
+            TokenToKVPoolClass = AscendTokenToKVPool
+        else:
+            TokenToKVPoolClass = MHATokenToKVPool
         self.full_kv_pool = TokenToKVPoolClass(
             size=size,
             page_size=self.page_size,
465
685
  device=device,
466
686
  enable_memory_saver=False,
467
687
  )
688
+ self.full_attention_layer_id_mapping = {
689
+ id: i for i, id in enumerate(full_attention_layer_ids)
690
+ }
691
+ k_size, v_size = self.get_kv_size_bytes()
692
+ self.mem_usage = (k_size + v_size) / GB
693
+
694
+ def get_kv_size_bytes(self):
695
+ return self.full_kv_pool.get_kv_size_bytes()
696
+
697
+ def get_contiguous_buf_infos(self):
698
+ return self.full_kv_pool.get_contiguous_buf_infos()
699
+
700
+ def _transfer_full_attention_id(self, layer_id: int):
701
+ if layer_id not in self.full_attention_layer_id_mapping:
702
+ raise ValueError(
703
+ f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
704
+ )
705
+ return self.full_attention_layer_id_mapping[layer_id]
706
+
707
+ def get_key_buffer(self, layer_id: int):
708
+ layer_id = self._transfer_full_attention_id(layer_id)
709
+ return self.full_kv_pool.get_key_buffer(layer_id)
710
+
711
+ def get_value_buffer(self, layer_id: int):
712
+ layer_id = self._transfer_full_attention_id(layer_id)
713
+ return self.full_kv_pool.get_value_buffer(layer_id)
714
+
715
+ def get_kv_buffer(self, layer_id: int):
716
+ layer_id = self._transfer_full_attention_id(layer_id)
717
+ return self.full_kv_pool.get_kv_buffer(layer_id)
718
+
719
+ def set_kv_buffer(
720
+ self,
721
+ layer: RadixAttention,
722
+ loc: torch.Tensor,
723
+ cache_k: torch.Tensor,
724
+ cache_v: torch.Tensor,
725
+ k_scale: float = 1.0,
726
+ v_scale: float = 1.0,
727
+ ):
728
+ layer_id = self._transfer_full_attention_id(layer.layer_id)
729
+ self.full_kv_pool.set_kv_buffer(
730
+ None,
731
+ loc,
732
+ cache_k,
733
+ cache_v,
734
+ k_scale,
735
+ v_scale,
736
+ layer_id_override=layer_id,
737
+ )
738
+
739
+ def get_v_head_dim(self):
740
+ return self.full_kv_pool.get_value_buffer(0).shape[-1]
741
+
742
+
743
+ class SWAKVPool(KVCache):
744
+ """KV cache with separate pools for full and SWA attention layers."""
745
+
746
+ def __init__(
747
+ self,
748
+ size: int,
749
+ size_swa: int,
750
+ swa_attention_layer_ids: List[int],
751
+ full_attention_layer_ids: List[int],
752
+ enable_kvcache_transpose: bool,
753
+ token_to_kv_pool_class: KVCache = MHATokenToKVPool,
754
+ **kwargs,
755
+ ):
756
+ self.size = size
757
+ self.size_swa = size_swa
758
+ self.swa_layer_nums = len(swa_attention_layer_ids)
759
+ self.full_layer_nums = len(full_attention_layer_ids)
760
+ kwargs["page_size"] = 1
761
+ kwargs["enable_memory_saver"] = False
762
+ # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
763
+ assert not enable_kvcache_transpose
764
+
765
+ self.swa_kv_pool = token_to_kv_pool_class(
766
+ size=size_swa,
767
+ layer_num=self.swa_layer_nums,
768
+ **kwargs,
769
+ )
770
+ self.full_kv_pool = token_to_kv_pool_class(
771
+ size=size,
772
+ layer_num=self.full_layer_nums,
773
+ **kwargs,
774
+ )
468
775
  self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
469
776
  for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
470
777
  self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
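
The SWAKVPool rewrite above replaces the hard-coded MHATokenToKVPool with an injected token_to_kv_pool_class and forwards the remaining constructor arguments through **kwargs, pinning page_size=1 and enable_memory_saver=False for both sub-pools. A standalone sketch of the injection pattern; FakePool and build_swa_pools are hypothetical stand-ins, not sglang APIs:

from typing import Dict, List, Tuple

class FakePool:
    # Stand-in exposing the constructor surface the real pool classes share.
    def __init__(self, size, layer_num, page_size, enable_memory_saver, **extra):
        self.size, self.layer_num, self.page_size = size, layer_num, page_size

def build_swa_pools(size, size_swa, swa_layers: List[int], full_layers: List[int],
                    pool_class=FakePool, **kwargs):
    kwargs["page_size"] = 1                 # SWA pools stay token-granular
    kwargs["enable_memory_saver"] = False
    swa = pool_class(size=size_swa, layer_num=len(swa_layers), **kwargs)
    full = pool_class(size=size, layer_num=len(full_layers), **kwargs)
    # global layer id -> (index within sub-pool, is_swa), as in layers_mapping
    mapping: Dict[int, Tuple[int, bool]] = {}
    for i, gid in enumerate(full_layers):
        mapping[gid] = (i, False)
    for i, gid in enumerate(swa_layers):
        mapping[gid] = (i, True)
    return swa, full, mapping

swa, full, mapping = build_swa_pools(1024, 256, swa_layers=[0, 1], full_layers=[2, 3])
assert mapping[1] == (1, True) and full.page_size == 1
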
@@ -613,8 +920,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
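
layer_id_override exists so HybridLinearKVPool.set_kv_buffer (shown earlier) can pass layer=None after it has already translated the global layer id into the full-attention pool's local index. A minimal sketch of the dispatch; Layer is a hypothetical stand-in for RadixAttention:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Layer:  # stand-in for RadixAttention
    layer_id: int

def resolve_layer_id(layer: Optional[Layer], layer_id_override: Optional[int]) -> int:
    # The override wins, which makes layer=None safe for pre-translated callers.
    if layer_id_override is not None:
        return layer_id_override
    return layer.layer_id

assert resolve_layer_id(Layer(layer_id=7), None) == 7
assert resolve_layer_id(None, 3) == 3
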
@@ -768,19 +1079,13 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
         kv_size_bytes = 0
         for kv_cache in self.kv_buffer:
-            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(kv_cache)
         return kv_size_bytes
 
     # for disagg
@@ -936,22 +1241,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             device=self.device,
         )
 
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
         assert hasattr(self, "v_buffer")
         kv_size_bytes = 0
         for k_cache in self.k_buffer:
-            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(k_cache)
         for v_cache in self.v_buffer:
-            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(v_cache)
         return kv_size_bytes
 
     def get_kv_buffer(self, layer_id: int):