sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/mem_cache/memory_pool.py
+++ b/sglang/srt/mem_cache/memory_pool.py
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
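
The three hunks above work together: `from __future__ import annotations` defers annotation evaluation, so `LayerDoneCounter`, imported only under `TYPE_CHECKING`, can appear in the `register_layer_transfer_counter` signature further down without importing `cache_controller` at runtime (avoiding a circular import). A minimal sketch of the pattern, using a hypothetical `heavy_module`:

```python
from __future__ import annotations  # annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never imported at runtime.
    from heavy_module import HeavyCounter  # hypothetical

class Pool:
    def register_counter(self, counter: HeavyCounter) -> None:
        # Safe at runtime: the annotation is never evaluated.
        self.counter = counter
```
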
@@ -97,6 +102,207 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class MambaPool:
+    def __init__(
+        self,
+        size: int,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        num_mamba_layers: int,
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        device: str,
+        speculative_num_draft_tokens: Optional[int] = None,
+    ):
+        conv_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + conv_state_shape,
+            dtype=conv_dtype,
+            device=device,
+        )
+        temporal_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + temporal_state_shape,
+            dtype=ssm_dtype,
+            device=device,
+        )
+        if speculative_num_draft_tokens is not None:
+            # Cache intermediate SSM states per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+            intermediate_ssm_state_cache = torch.empty(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    temporal_state_shape[0],
+                    temporal_state_shape[1],
+                    temporal_state_shape[2],
+                ),
+                dtype=ssm_dtype,
+                device="cuda",
+            )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.empty(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
+            self.mamba_cache = (
+                conv_state,
+                temporal_state,
+                intermediate_ssm_state_cache,
+                intermediate_conv_window_cache,
+            )
+        else:
+            self.mamba_cache = (conv_state, temporal_state)
+        self.size = size
+        self.free_slots = list(range(size))
+        self.mem_usage = self.get_mamba_size() / GB
+        logger.info(
+            f"Mamba Cache is allocated. "
+            f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
+            f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB"
+        )
+
+    def get_mamba_params_all_layers(self):
+        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_params(self, layer_id: int):
+        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_size(self):
+        return (
+            np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
+            + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
+        )
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, int):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+        # Zero the recycled slots so a reused slot starts from a clean state.
+        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+
+    def clear(self):
+        self.free_slots = list(range(self.size))
+
+
+class HybridReqToTokenPool(ReqToTokenPool):
+    """A memory pool that maps a request to its token locations."""
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        mamba_layers: List[int],
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        speculative_num_draft_tokens: int,
+    ):
+        super().__init__(
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+        )
+
+        self.mamba_pool = MambaPool(
+            size,
+            conv_dtype,
+            ssm_dtype,
+            len(mamba_layers),
+            conv_state_shape,
+            temporal_state_shape,
+            device,
+            speculative_num_draft_tokens,
+        )
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+
+        self.device = device
+        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
+            size, dtype=torch.int32, device=self.device
+        )
+
+        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+    # For a chunked prefill request, we do not need to allocate a new mamba
+    # cache slot; the slot already allocated for the request can be reused.
+    def alloc(
+        self, need_size: int, reqs: Optional[List["Req"]] = None
+    ) -> Optional[List[int]]:
+        select_index = super().alloc(need_size)
+        if select_index is None:
+            return None
+
+        mamba_index = []
+        for req in reqs:
+            rid = req.rid
+            if rid in self.rid_to_mamba_index_mapping:
+                mid = self.rid_to_mamba_index_mapping[rid]
+            elif (mid := self.mamba_pool.alloc(1)) is not None:
+                mid = mid[0]
+                self.rid_to_mamba_index_mapping[rid] = mid
+                self.mamba_index_to_rid_mapping[mid] = rid
+            mamba_index.append(mid)
+        assert len(select_index) == len(
+            mamba_index
+        ), "Not enough space for mamba cache; try increasing --max-mamba-cache-size."
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
+        return select_index
+
+    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+        return self.req_index_to_mamba_index_mapping[req_indices]
+
+    def get_mamba_params(self, layer_id: int):
+        assert layer_id in self.mamba_map
+        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+
+    def get_mamba_params_all_layers(self):
+        return self.mamba_pool.get_mamba_params_all_layers()
+
+    # For chunked prefill, the mamba cache must not be freed here; it is still
+    # needed by the request's remaining chunks.
+    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        super().free(free_index)
+        if free_mamba_cache:
+            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+            mamba_index_list = mamba_index.tolist()
+            if isinstance(mamba_index_list, int):
+                mamba_index_list = [mamba_index_list]
+            self.mamba_pool.free(mamba_index_list)
+            for mid in mamba_index_list:
+                rid = self.mamba_index_to_rid_mapping[mid]
+                self.mamba_index_to_rid_mapping.pop(mid)
+                self.rid_to_mamba_index_mapping.pop(rid)
+
+    def clear(self):
+        super().clear()
+        self.mamba_pool.clear()
+
+
 class KVCache(abc.ABC):
     @abc.abstractmethod
     def __init__(
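
The `MambaPool` added above is a plain free-list allocator over preallocated per-request state slots: `alloc(n)` pops up to `n` slot indices or returns `None` when exhausted, and `free` returns the indices and zeroes their conv/SSM state so a recycled slot starts clean. `HybridReqToTokenPool` then keeps `rid -> slot` maps so a chunked-prefill request reuses its existing slot across chunks. A self-contained sketch of the free-list discipline (illustrative names, not the sglang API):

```python
from typing import List, Optional

class SlotPool:
    """Free-list allocator mirroring MambaPool.alloc/free above."""

    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None  # exhausted: caller must back off or evict
        selected = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return selected

    def free(self, indices: List[int]) -> None:
        # The real MambaPool also zeroes the conv/SSM state at these indices.
        self.free_slots.extend(indices)

pool = SlotPool(4)
first = pool.alloc(3)          # -> [0, 1, 2]
assert pool.alloc(2) is None   # only one slot remains
pool.free(first)               # slots become reusable afterwards
```
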
@@ -130,6 +336,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
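
`_finalize_allocation_log` centralizes the allocation log line and `mem_usage` bookkeeping that each pool previously duplicated in its `__init__` (see the deletions in the hunks below). It branches on the shape of `get_kv_size_bytes()`: MHA-style pools return a `(k_size, v_size)` tuple, MLA-style pools a single byte count. A rough standalone illustration of the dispatch:

```python
GB = 1024 * 1024 * 1024

def describe_allocation(num_tokens: int, kv_size_bytes) -> str:
    # Tuple => separate K and V buffers (MHA); int => fused KV buffer (MLA).
    if isinstance(kv_size_bytes, tuple):
        k_size, v_size = kv_size_bytes
        return (f"#tokens: {num_tokens}, K size: {k_size / GB:.2f} GB, "
                f"V size: {v_size / GB:.2f} GB")
    return f"#tokens: {num_tokens}, KV size: {kv_size_bytes / GB:.2f} GB"

print(describe_allocation(4096, (2 * GB, 2 * GB)))  # MHA-style
print(describe_allocation(4096, 3 * GB))            # MLA-style
```
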
@@ -152,7 +381,7 @@
     ) -> None:
         raise NotImplementedError()
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter
 
     def get_cpu_copy(self, indices):
@@ -205,15 +434,9 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
-
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -352,7 +575,6 @@
         # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
         return self._get_key_buffer(layer_id)
 
     def _get_value_buffer(self, layer_id: int):
@@ -420,50 +642,119 @@
         )
 
 
-class SWAKVPool(KVCache):
-    """KV cache with separate pools for full and SWA attention layers."""
+class HybridLinearKVPool(KVCache):
+    """KV cache with separate pools for full and linear attention layers."""
 
     def __init__(
         self,
         size: int,
-        size_swa: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
-        swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
     ):
         self.size = size
-        self.size_swa = size_swa
         self.dtype = dtype
         self.device = device
-        self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
         self.page_size = 1
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-        TokenToKVPoolClass = MHATokenToKVPool
-        self.swa_kv_pool = TokenToKVPoolClass(
-            size=size_swa,
+        self.full_kv_pool = MHATokenToKVPool(
+            size=size,
             page_size=self.page_size,
             dtype=dtype,
             head_num=head_num,
             head_dim=head_dim,
-            layer_num=self.swa_layer_nums,
+            layer_num=self.full_layer_nums,
             device=device,
             enable_memory_saver=False,
         )
-        self.full_kv_pool = TokenToKVPoolClass(
+        self.full_attention_layer_id_mapping = {
+            id: i for i, id in enumerate(full_attention_layer_ids)
+        }
+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
+    def get_kv_size_bytes(self):
+        return self.full_kv_pool.get_kv_size_bytes()
+
+    def get_contiguous_buf_infos(self):
+        return self.full_kv_pool.get_contiguous_buf_infos()
+
+    def _transfer_full_attention_id(self, layer_id: int):
+        if layer_id not in self.full_attention_layer_id_mapping:
+            raise ValueError(
+                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+            )
+        return self.full_attention_layer_id_mapping[layer_id]
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_key_buffer(layer_id)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_value_buffer(layer_id)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_kv_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+        layer_id = self._transfer_full_attention_id(layer.layer_id)
+        self.full_kv_pool.set_kv_buffer(
+            None,
+            loc,
+            cache_k,
+            cache_v,
+            k_scale,
+            v_scale,
+            layer_id_override=layer_id,
+        )
+
+
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+
+        self.swa_kv_pool = token_to_kv_pool_class(
+            size=size_swa,
+            layer_num=self.swa_layer_nums,
+            **kwargs,
+        )
+        self.full_kv_pool = token_to_kv_pool_class(
             size=size,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.full_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
         )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
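
`HybridLinearKVPool` backs only the full-attention layers with a dense `MHATokenToKVPool`; linear-attention layers keep their state in the mamba pool instead. `_transfer_full_attention_id` remaps a model-global layer id to that layer's position inside the smaller pool, via a dict built from the ordered id list:

```python
# Illustrative: layers 3, 7, 11 of a 12-layer model use full attention.
full_attention_layer_ids = [3, 7, 11]
mapping = {layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)}

assert mapping[7] == 1   # global layer 7 lives at pool index 1
assert 5 not in mapping  # a linear-attention layer has no KV buffer here
```
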
@@ -768,13 +1059,7 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -918,6 +1203,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.kv_lora_rank,
             ),
             dtype=self.store_dtype,
@@ -928,19 +1214,14 @@
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.qk_rope_head_dim,
             ),
             dtype=self.store_dtype,
             device=self.device,
         )
 
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
@@ -1000,9 +1281,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
 
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
 
         if cache_v is None:
             cache_k, cache_v = cache_k.split(
--- a/sglang/srt/mem_cache/memory_pool_host.py
+++ b/sglang/srt/mem_cache/memory_pool_host.py
@@ -3,16 +3,17 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional
 
 import psutil
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu
 
 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)
 
     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
@@ -464,11 +465,11 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
-        local_rank = get_tensor_model_parallel_rank()
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
@@ -501,20 +502,23 @@
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))
 
         key_list = []
         buf_list = []
 
-        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+        for i in range(len(keys)):
+            key = keys[i]
             key_list.append(f"{key}-k")
-            buf_list.append(self.k_buffer[i : i + self.page_size])
             key_list.append(f"{key}-v")
-            buf_list.append(self.v_buffer[i : i + self.page_size])
+            if indices is not None:
+                index = indices[i * self.page_size]
+                buf_list.append(self.k_buffer[index : index + self.page_size])
+                buf_list.append(self.v_buffer[index : index + self.page_size])
 
-        return key_list, buf_list
+        return key_list, buf_list, 2
 
 
 class MLATokenToKVPoolHost(HostKVCache):
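
`get_buffer_with_hash` now returns a third value: how many key/buffer entries it emits per page hash. The MHA host pool returns 2 (a `"{key}-k"` and a `"{key}-v"` entry per page), while the MLA pool below returns 1 (fused KV). With `indices=None`, only keys are produced, which lets a storage backend probe for existence without pinning buffers. A small illustration of the count contract:

```python
def expected_entries(num_page_hashes: int, entries_per_page: int) -> int:
    # MHA host pool: entries_per_page == 2 ("-k" and "-v" per page hash).
    # MLA host pool: entries_per_page == 1 (one fused KV buffer per page).
    return num_page_hashes * entries_per_page

assert expected_entries(8, 2) == 16  # MHA: 8 pages -> 16 keys
assert expected_entries(8, 1) == 8   # MLA: 8 pages -> 8 keys
```
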
@@ -704,10 +708,11 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
@@ -728,13 +733,15 @@
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))
 
         buf_list = []
 
-        for i in range(0, len(indices), self.page_size):
-            buf_list.append(self.kv_buffer[i : i + self.page_size])
+        if indices is not None:
+            for i in range(len(keys)):
+                index = indices[i * self.page_size]
+                buf_list.append(self.kv_buffer[index : index + self.page_size])
 
-        return keys, buf_list
+        return keys, buf_list, 1
--- a/sglang/srt/mem_cache/radix_cache.py
+++ b/sglang/srt/mem_cache/radix_cache.py
@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()
 
         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
@@ -62,7 +60,6 @@
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each page
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -195,7 +192,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )
 
-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
         if self.disable:
             return 0
 
@@ -240,7 +237,7 @@
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -261,7 +258,9 @@
         page_aligned_token_ids = token_ids[:page_aligned_len]
 
         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
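
`cache_unfinished_req` now forwards a `chunked` flag into `insert`, so the tree can tell an intermediate chunked-prefill insertion from a final one (subclasses can use this, for example, to defer host write-back until a request completes). The plumbing is straightforward; sketched below with simplified, hypothetical signatures and request fields:

```python
class SketchCache:
    def insert(self, key, value=None, chunked: bool = False) -> int:
        # A subclass might skip write-through while chunked=True and only
        # back up the prefix once the final chunk arrives.
        return 0

    def cache_unfinished_req(self, req, chunked: bool = False) -> None:
        # req.fill_ids / req.kv_indices are illustrative field names.
        self.insert(req.fill_ids, req.kv_indices, chunked=chunked)
```
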
--- a/sglang/srt/mem_cache/radix_cache_cpp.py
+++ b/sglang/srt/mem_cache/radix_cache_cpp.py
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids