sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tp_worker_overlap_thread.py

@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool,
+            self.worker.model_runner.token_to_kv_pool_allocator,
         )

     def forward_thread_func(self):
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
                    logits_output.next_token_logprobs.tolist()
                )
            if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = (
+                logits_output.input_token_logprobs = tuple(
                    logits_output.input_token_logprobs.tolist()
                )
            next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
             sampling_info,
             sampling_info_done=threading.Event(),
-            scaling_penalties=sampling_info.scaling_penalties,
-            linear_penalties=sampling_info.linear_penalties,
+            penalizer_orchestrator=None,
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
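
The sampling-state copy above relies on dataclasses.replace, which returns a shallow copy of a dataclass with selected fields overridden. A minimal sketch of that pattern follows; SamplingState and its fields are illustrative stand-ins, not sglang's real SamplingBatchInfo:

import dataclasses
import threading
from typing import Optional


@dataclasses.dataclass
class SamplingState:  # illustrative stand-in, not sglang's SamplingBatchInfo
    temperature: float
    sampling_info_done: Optional[threading.Event] = None
    penalizer_orchestrator: Optional[object] = None


state = SamplingState(temperature=0.7, penalizer_orchestrator=object())

# Shallow copy with selected fields overridden, mirroring the hunk above:
detached = dataclasses.replace(
    state,
    sampling_info_done=threading.Event(),
    penalizer_orchestrator=None,
)
assert detached.temperature == 0.7 and detached.penalizer_orchestrator is None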

sglang/srt/mem_cache/chunk_cache.py

@@ -1,29 +1,33 @@
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

-from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
+import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req


 class ChunkCacheEntry:
-    def __init__(self, rid, value):
+    def __init__(self, rid: str, value: torch.Tensor):
         self.rid = rid
         self.value = value


 class ChunkCache(BasePrefixCache):
     def __init__(
-        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.entries: Dict[str, ChunkCacheEntry] = {}

         self.reset()

@@ -48,16 +52,13 @@ class ChunkCache(BasePrefixCache):
             req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
-        self.token_to_kv_pool.free(kv_indices)
+        self.token_to_kv_pool_allocator.free(kv_indices)

         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.fill_ids)
-        else:
-            token_id_len = len(token_ids)
+    def cache_unfinished_req(self, req: Req):
+        token_id_len = len(req.fill_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_id_len
@@ -86,5 +87,11 @@ class ChunkCache(BasePrefixCache):
     def evictable_size(self):
         return 0

+    def pretty_print(self):
+        return ""
+
     def protected_size(self):
         return 0
+
+    def pretty_print(self):
+        return ""
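
ChunkCache now receives the index allocator instead of the KV pool and derives the token count from req.fill_ids. A minimal sketch of the slice-and-free pattern under those assumptions, with hypothetical stand-in classes (the real ones live in sglang.srt.mem_cache.memory_pool and sglang.srt.managers.schedule_batch):

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeReq:  # hypothetical stand-in for schedule_batch.Req
    rid: str
    req_pool_idx: int
    fill_ids: List[int] = field(default_factory=list)


class FakeReqToTokenPool:  # hypothetical stand-in for ReqToTokenPool
    def __init__(self):
        self.req_to_token = [[10, 11, 12, 13]]

    def free(self, idx: int):
        print(f"freed request slot {idx}")


class FakeTokenToKVPoolAllocator:  # hypothetical stand-in for TokenToKVPoolAllocator
    def free(self, kv_indices):
        print(f"freed {len(kv_indices)} KV cache slots")


req_pool = FakeReqToTokenPool()
allocator = FakeTokenToKVPoolAllocator()
req = FakeReq(rid="r0", req_pool_idx=0, fill_ids=[1, 2, 3])

# Slice the request's token row by a token count, as cache_unfinished_req now
# does with len(req.fill_ids), then release the request slot and the KV slots
# through the renamed allocator, as cache_finished_req does in the hunk above.
kv_indices = req_pool.req_to_token[req.req_pool_idx][: len(req.fill_ids)]
req_pool.free(req.req_pool_idx)
allocator.free(kv_indices)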

sglang/srt/mem_cache/hiradix_cache.py

@@ -0,0 +1,394 @@
+import heapq
+import logging
+import time
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MHATokenToKVPoolHost,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
+
+logger = logging.getLogger(__name__)
+
+
+class HiRadixCache(RadixCache):
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: MHATokenToKVPool,
+    ):
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool, self.token_to_kv_pool_host
+        )
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through = {}
+        # record the node segments with ongoing load back
+        self.ongoing_load_back = {}
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = 1
+        self.load_back_threshold = 10
+        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
+
+    def reset(self):
+        TreeNode.counter = 0
+        self.cache_controller.reset()
+        self.token_to_kv_pool_host.clear()
+        super().reset()
+
+    def get_height(self, node: TreeNode):
+        height = 0
+        while node != self.root_node:
+            node = node.parent
+            height += 1
+        return height
+
+    def write_backup(self, node: TreeNode):
+        host_indices = self.cache_controller.write(
+            device_indices=node.value,
+            priority=-self.get_height(node),
+            node_id=node.id,
+        )
+        if host_indices is None:
+            self.evict_host(len(node.value))
+            host_indices = self.cache_controller.write(
+                device_indices=node.value,
+                priority=-self.get_height(node),
+                node_id=node.id,
+            )
+        if host_indices is not None:
+            node.host_value = host_indices
+            self.ongoing_write_through[node.id] = node
+            self.inc_lock_ref(node)
+        else:
+            return None
+
+        return len(host_indices)
+
+    def inc_hit_count(self, node: TreeNode):
+        if self.cache_controller.write_policy != "write_through_selective":
+            return
+        node.hit_count += 1
+        if node.host_value is None and node.hit_count > self.write_through_threshold:
+            self.write_backup(node)
+            node.hit_count = 0
+
+    def writing_check(self):
+        while not self.cache_controller.ack_write_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_write_queue.get_nowait()
+                self.dec_lock_ref(self.ongoing_write_through[ack_id])
+                # clear the reference
+                del self.ongoing_write_through[ack_id]
+            except Exception:
+                break
+
+    def loading_check(self):
+        while not self.cache_controller.ack_load_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_load_queue.get_nowait()
+                start_node, end_node = self.ongoing_load_back[ack_id]
+                self.dec_lock_ref(end_node)
+                while end_node != start_node:
+                    assert end_node.loading
+                    end_node.loading = False
+                    end_node = end_node.parent
+                # clear the reference
+                del self.ongoing_load_back[ack_id]
+            except Exception:
+                break
+
+    def evictable_size(self):
+        self.writing_check()
+        self.loading_check()
+        return self.evictable_size_
+
+    def evict(self, num_tokens: int, evict_callback=None):
+        leaves = self._collect_leaves_device()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        pending_nodes = []
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+
+            if x.lock_ref > 0:
+                continue
+
+            if x.host_value is None:
+                if self.cache_controller.write_policy == "write_back":
+                    num_evicted += self.write_backup(x)
+                elif self.cache_controller.write_policy == "write_through_selective":
+                    num_evicted += self._evict_write_through_selective(x)
+                else:
+                    assert (
+                        self.cache_controller.write_policy != "write_through"
+                    ), "write_through should be inclusive"
+                    raise NotImplementedError
+            else:
+                num_evicted += self._evict_write_through(x)
+
+            for child in x.parent.children.values():
+                if child in pending_nodes:
+                    continue
+                if not child.evicted:
+                    break
+            else:
+                # all children are evicted or no children
+                heapq.heappush(leaves, x.parent)
+
+        if self.cache_controller.write_policy == "write_back":
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                self.writing_check()
+                time.sleep(0.1)
+
+    def _evict_write_through(self, node: TreeNode):
+        # evict a node already written to host
+        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        assert num_evicted > 0
+        self.evictable_size_ -= num_evicted
+        node.value = None
+        return num_evicted
+
+    def _evict_write_through_selective(self, node: TreeNode):
+        # evict a node not initiated write to host
+        self.cache_controller.mem_pool_device.free(node.value)
+        num_evicted = len(node.value)
+        self._delete_leaf(node)
+        return num_evicted
+
+    def evict_host(self, num_tokens: int):
+        leaves = self._collect_leaves()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+            if x == self.root_node:
+                break
+            # only evict the host value of evicted nodes
+            if not x.evicted:
+                continue
+            assert x.lock_ref == 0 and x.host_value is not None
+
+            assert self.cache_controller.evict_host(x.host_value) > 0
+            for k, v in x.parent.children.items():
+                if v == x:
+                    break
+            del x.parent.children[k]
+
+            if len(x.parent.children) == 0 and x.parent.evicted:
+                heapq.heappush(leaves, x.parent)
+
+    def load_back(
+        self, node: TreeNode, mem_quota: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        # todo: more loading policies
+
+        last_hit_node = node
+        nodes_to_load = []
+        while node.evicted:
+            assert (
+                node.backuped
+            ), "No backup available on evicted nodes, should not happen"
+            nodes_to_load.insert(0, node)
+            node = node.parent
+        else:
+            ancester_node = node
+
+        # protect the ancestor nodes from eviction
+        delta = self.inc_lock_ref(ancester_node)
+
+        # load it all or not at all
+        host_indices = torch.cat([n.host_value for n in nodes_to_load])
+        if len(host_indices) < self.load_back_threshold or (
+            len(host_indices) > mem_quota + delta if mem_quota is not None else False
+        ):
+            # skip loading back if the total size is too small or exceeding the memory quota
+            self.dec_lock_ref(ancester_node)
+            return None
+
+        device_indices = self.cache_controller.load(
+            host_indices=host_indices, node_id=last_hit_node.id
+        )
+        if device_indices is None:
+            self.evict(len(host_indices))
+            device_indices = self.cache_controller.load(
+                host_indices=host_indices, node_id=last_hit_node.id
+            )
+        self.dec_lock_ref(ancester_node)
+        if device_indices is None:
+            # no sufficient GPU memory to load back KV caches
+            return None
+
+        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        offset = 0
+        for node in nodes_to_load:
+            node.value = device_indices[offset : offset + len(node.host_value)]
+            offset += len(node.host_value)
+            node.loading = True
+        self.evictable_size_ += len(device_indices)
+        self.inc_lock_ref(last_hit_node)
+
+        return device_indices
+
+    def loading_complete(self, node: TreeNode):
+        self.loading_check()
+        return node.loading == False
+
+    def init_load_back(
+        self,
+        last_node: TreeNode,
+        prefix_indices: torch.Tensor,
+        mem_quota: Optional[int] = None,
+    ):
+        assert (
+            len(prefix_indices) == 0 or prefix_indices.is_cuda
+        ), "indices of device kV caches should be on GPU"
+        if last_node.evicted:
+            loading_values = self.load_back(last_node, mem_quota)
+            if loading_values is not None:
+                prefix_indices = (
+                    loading_values
+                    if len(prefix_indices) == 0
+                    else torch.cat([prefix_indices, loading_values])
+                )
+                logger.debug(
+                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
+                )
+
+            while last_node.evicted:
+                last_node = last_node.parent
+
+        return last_node, prefix_indices
+
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
+                if not new_node.evicted:
+                    value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                self.inc_hit_count(child)
+                if not child.evicted:
+                    value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len: int):
+        # child node split into new_node -> child
+        new_node = TreeNode()
+        new_node.children = {key[split_len]: child}
+        new_node.parent = child.parent
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
+        new_node.loading = child.loading
+
+        # split value and host value if exists
+        if child.evicted:
+            new_node.value = None
+        else:
+            new_node.value = child.value[:split_len]
+            child.value = child.value[split_len:]
+        if child.host_value is not None:
+            new_node.host_value = child.host_value[:split_len]
+            child.host_value = child.host_value[split_len:]
+        child.parent = new_node
+        child.key = child.key[split_len:]
+        new_node.parent.children[key[0]] = new_node
+        return new_node
+
+    def _insert_helper(self, node: TreeNode, key: List, value):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+
+            if prefix_len == len(child.key):
+                if child.evicted:
+                    # change the reference if the node is evicted
+                    # this often happens in the case of KV cache recomputation
+                    child.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(child.host_value)
+                    self.evictable_size_ += len(value[:prefix_len])
+                    return self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+                else:
+                    self.inc_hit_count(child)
+                    return prefix_len + self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+
+            # partial match, split the node
+            new_node = self._split_node(child.key, child, prefix_len)
+            if new_node.evicted:
+                new_node.value = value[:prefix_len]
+                self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                self.evictable_size_ += len(new_node.value)
+                return self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+            else:
+                self.inc_hit_count(new_node)
+                return prefix_len + self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = value
+            node.children[key[0]] = new_node
+            self.evictable_size_ += len(value)
+
+            if self.cache_controller.write_policy == "write_through":
+                self.write_backup(new_node)
+        return 0
+
+    def _collect_leaves_device(self):
+        def is_leaf(node):
+            if node.evicted:
+                return False
+            if node == self.root_node:
+                return False
+            if len(node.children) == 0:
+                return True
+            for child in node.children.values():
+                if not child.evicted:
+                    return False
+            return True
+
+        ret_list = []
+        stack = [self.root_node]
+        while stack:
+            cur_node = stack.pop()
+            if is_leaf(cur_node):
+                ret_list.append(cur_node)
+            else:
+                for cur_child in cur_node.children.values():
+                    if not cur_child.evicted:
+                        stack.append(cur_child)
+        return ret_list
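
HiRadixCache.evict above walks evictable leaves through a heap and only considers a parent once all of its children are gone. A self-contained sketch of that heap-of-leaves eviction pattern, using a simplified stand-in node class (no host offloading or lock refs, unlike the real TreeNode):

import heapq
import time
from typing import List, Optional


class Node:  # simplified stand-in for TreeNode
    def __init__(self, num_tokens: int, parent: Optional["Node"] = None):
        self.num_tokens = num_tokens
        self.parent = parent
        self.children: List["Node"] = []
        self.last_access_time = time.time()
        if parent is not None:
            parent.children.append(self)

    def __lt__(self, other: "Node"):
        # heapq pops the least recently accessed node first
        return self.last_access_time < other.last_access_time


def evict(leaves: List[Node], target: int) -> int:
    heapq.heapify(leaves)
    evicted = 0
    while evicted < target and leaves:
        node = heapq.heappop(leaves)
        evicted += node.num_tokens
        parent = node.parent
        if parent is not None:
            parent.children.remove(node)
            if not parent.children and parent.parent is not None:
                # the parent became a leaf itself, so it is now evictable
                heapq.heappush(leaves, parent)
    return evicted


root = Node(num_tokens=0)
leaves = [Node(8, parent=root), Node(4, parent=root)]
print(evict(leaves, target=10))  # 12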

sglang/srt/mem_cache/memory_pool.py

@@ -20,9 +20,12 @@ Memory pool.

 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-BaseTokenToKVPool maps a token location to its KV cache data.
+TokenToKVPoolAllocator maps a token location to its KV cache data.
+KVCache actually holds the physical kv cache. Allocation indices are allocated
+by TokenToKVPoolAllocator
 """

+import abc
 import logging
 import threading
 from enum import IntEnum
@@ -89,7 +92,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class TokenToKVPoolAllocator:
     """A memory pool that maps a token location to its kv cache data."""

     def __init__(
@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
         self.device = device

         self.free_slots = None
@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
         self.is_in_free_group = False
         self.free_group = []

+
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

+    @abc.abstractmethod
     def set_kv_buffer(
         self,
         layer: RadixAttention,
@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
         raise NotImplementedError()


-class MHATokenToKVPool(BaseTokenToKVPool):
+class MHATokenToKVPool(KVCache):

     def __init__(
         self,
@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -192,7 +203,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):

         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
-            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )

     def _create_buffers(self):
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


-class MLATokenToKVPool(BaseTokenToKVPool):
+class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank

         memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
             self.kv_buffer[layer_id][loc] = cache_k


-class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         heavy_channel_num: int,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -437,7 +460,7 @@ def synchronized(func):
     return wrapper


-class MLATokenToKVPoolHost:
+class MHATokenToKVPoolHost:

     def __init__(
         self,
@@ -502,6 +525,9 @@ class MLATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
     @debug_timing
     def transfer(self, indices, flat_data):
         # backup prepared data from device to host
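
The memory_pool.py changes split index allocation (TokenToKVPoolAllocator) from physical storage (KVCache, now an abc.ABC whose buffer accessors are declared with @abc.abstractmethod and implemented by MHATokenToKVPool, MLATokenToKVPool, and DoubleSparseTokenToKVPool). A minimal sketch of that abstract-base pattern; class names and buffer shapes here are illustrative, not the exact sglang definitions:

import abc
from typing import Tuple

import torch


class KVCacheLike(abc.ABC):
    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor: ...

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor: ...

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: ...


class TinyMHAPool(KVCacheLike):
    def __init__(self, num_layers: int, size: int, head_num: int, head_dim: int):
        # one key buffer and one value buffer per layer, indexed by token slot
        self.k_buffer = [torch.zeros(size, head_num, head_dim) for _ in range(num_layers)]
        self.v_buffer = [torch.zeros(size, head_num, head_dim) for _ in range(num_layers)]

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)


pool = TinyMHAPool(num_layers=2, size=16, head_num=4, head_dim=8)
k, v = pool.get_kv_buffer(0)
print(k.shape, v.shape)  # torch.Size([16, 4, 8]) torch.Size([16, 4, 8])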