sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool,
+            self.worker.model_runner.token_to_kv_pool_allocator,
         )
 
     def forward_thread_func(self):
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
                 logits_output.next_token_logprobs.tolist()
             )
         if logits_output.input_token_logprobs is not None:
-            logits_output.input_token_logprobs = (
+            logits_output.input_token_logprobs = tuple(
                 logits_output.input_token_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
             sampling_info,
             sampling_info_done=threading.Event(),
-            scaling_penalties=sampling_info.scaling_penalties,
-            linear_penalties=sampling_info.linear_penalties,
+            penalizer_orchestrator=None,
         )
 
         # A cuda stream sync here to avoid the cuda illegal memory access error.
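
For readers unfamiliar with the pattern in the last hunk: dataclasses.replace makes a shallow copy of a dataclass instance with selected fields overridden, which is how the overlap thread gives each batch a fresh threading.Event while dropping the penalizer state in a single call. A minimal, self-contained sketch (SamplingInfoSketch is a hypothetical stand-in, not sglang's SamplingBatchInfo):

import dataclasses
import threading

@dataclasses.dataclass
class SamplingInfoSketch:  # hypothetical stand-in for the real sampling info
    temperature: float
    penalizer_orchestrator: object = None
    sampling_info_done: threading.Event = None

info = SamplingInfoSketch(temperature=0.7, penalizer_orchestrator=object())
copy = dataclasses.replace(
    info,
    sampling_info_done=threading.Event(),  # fresh synchronization primitive per batch
    penalizer_orchestrator=None,           # do not share penalizer state across threads
)
assert copy.temperature == info.temperature and copy.penalizer_orchestrator is None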
--- a/sglang/srt/mem_cache/chunk_cache.py
+++ b/sglang/srt/mem_cache/chunk_cache.py
@@ -1,29 +1,33 @@
 from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 
-from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
+import torch
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
 
 
 class ChunkCacheEntry:
-    def __init__(self, rid, value):
+    def __init__(self, rid: str, value: torch.Tensor):
         self.rid = rid
         self.value = value
 
 
 class ChunkCache(BasePrefixCache):
     def __init__(
-        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.entries: Dict[str, ChunkCacheEntry] = {}
 
         self.reset()
 
@@ -48,16 +52,13 @@ class ChunkCache(BasePrefixCache):
             req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
-        self.token_to_kv_pool.free(kv_indices)
+        self.token_to_kv_pool_allocator.free(kv_indices)
 
         if req.rid in self.entries:
             del self.entries[req.rid]
 
-    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.fill_ids)
-        else:
-            token_id_len = len(token_ids)
+    def cache_unfinished_req(self, req: Req):
+        token_id_len = len(req.fill_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_id_len
@@ -86,5 +87,11 @@ class ChunkCache(BasePrefixCache):
     def evictable_size(self):
         return 0
 
+    def pretty_print(self):
+        return ""
+
     def protected_size(self):
         return 0
+
+    def pretty_print(self):
+        return ""
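
The net effect on callers is that ChunkCache now takes (and frees into) a TokenToKVPoolAllocator instead of the KV pool itself, and cache_unfinished_req lost its token_ids parameter. A hedged construction sketch (the pool arguments are stand-ins; in sglang they are created by the model runner and scheduler, and this assumes sglang 0.4.3.post4 is installed):

from sglang.srt.mem_cache.chunk_cache import ChunkCache

def build_chunk_cache(req_to_token_pool, token_to_kv_pool_allocator):
    # ChunkCache frees request slots into req_to_token_pool and KV indices
    # into the allocator; it never touches the physical KV buffers directly.
    return ChunkCache(
        req_to_token_pool=req_to_token_pool,
        token_to_kv_pool_allocator=token_to_kv_pool_allocator,
    )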
--- /dev/null
+++ b/sglang/srt/mem_cache/hiradix_cache.py
@@ -0,0 +1,394 @@
+import heapq
+import logging
+import time
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MHATokenToKVPoolHost,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
+
+logger = logging.getLogger(__name__)
+
+
+class HiRadixCache(RadixCache):
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: MHATokenToKVPool,
+    ):
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool, self.token_to_kv_pool_host
+        )
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through = {}
+        # record the node segments with ongoing load back
+        self.ongoing_load_back = {}
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = 1
+        self.load_back_threshold = 10
+        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
+
+    def reset(self):
+        TreeNode.counter = 0
+        self.cache_controller.reset()
+        self.token_to_kv_pool_host.clear()
+        super().reset()
+
+    def get_height(self, node: TreeNode):
+        height = 0
+        while node != self.root_node:
+            node = node.parent
+            height += 1
+        return height
+
+    def write_backup(self, node: TreeNode):
+        host_indices = self.cache_controller.write(
+            device_indices=node.value,
+            priority=-self.get_height(node),
+            node_id=node.id,
+        )
+        if host_indices is None:
+            self.evict_host(len(node.value))
+            host_indices = self.cache_controller.write(
+                device_indices=node.value,
+                priority=-self.get_height(node),
+                node_id=node.id,
+            )
+        if host_indices is not None:
+            node.host_value = host_indices
+            self.ongoing_write_through[node.id] = node
+            self.inc_lock_ref(node)
+        else:
+            return None
+
+        return len(host_indices)
+
+    def inc_hit_count(self, node: TreeNode):
+        if self.cache_controller.write_policy != "write_through_selective":
+            return
+        node.hit_count += 1
+        if node.host_value is None and node.hit_count > self.write_through_threshold:
+            self.write_backup(node)
+            node.hit_count = 0
+
+    def writing_check(self):
+        while not self.cache_controller.ack_write_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_write_queue.get_nowait()
+                self.dec_lock_ref(self.ongoing_write_through[ack_id])
+                # clear the reference
+                del self.ongoing_write_through[ack_id]
+            except Exception:
+                break
+
+    def loading_check(self):
+        while not self.cache_controller.ack_load_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_load_queue.get_nowait()
+                start_node, end_node = self.ongoing_load_back[ack_id]
+                self.dec_lock_ref(end_node)
+                while end_node != start_node:
+                    assert end_node.loading
+                    end_node.loading = False
+                    end_node = end_node.parent
+                # clear the reference
+                del self.ongoing_load_back[ack_id]
+            except Exception:
+                break
+
+    def evictable_size(self):
+        self.writing_check()
+        self.loading_check()
+        return self.evictable_size_
+
+    def evict(self, num_tokens: int, evict_callback=None):
+        leaves = self._collect_leaves_device()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        pending_nodes = []
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+
+            if x.lock_ref > 0:
+                continue
+
+            if x.host_value is None:
+                if self.cache_controller.write_policy == "write_back":
+                    num_evicted += self.write_backup(x)
+                elif self.cache_controller.write_policy == "write_through_selective":
+                    num_evicted += self._evict_write_through_selective(x)
+                else:
+                    assert (
+                        self.cache_controller.write_policy != "write_through"
+                    ), "write_through should be inclusive"
+                    raise NotImplementedError
+            else:
+                num_evicted += self._evict_write_through(x)
+
+            for child in x.parent.children.values():
+                if child in pending_nodes:
+                    continue
+                if not child.evicted:
+                    break
+            else:
+                # all children are evicted or no children
+                heapq.heappush(leaves, x.parent)
+
+        if self.cache_controller.write_policy == "write_back":
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                self.writing_check()
+                time.sleep(0.1)
+
+    def _evict_write_through(self, node: TreeNode):
+        # evict a node already written to host
+        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        assert num_evicted > 0
+        self.evictable_size_ -= num_evicted
+        node.value = None
+        return num_evicted
+
+    def _evict_write_through_selective(self, node: TreeNode):
+        # evict a node not initiated write to host
+        self.cache_controller.mem_pool_device.free(node.value)
+        num_evicted = len(node.value)
+        self._delete_leaf(node)
+        return num_evicted
+
+    def evict_host(self, num_tokens: int):
+        leaves = self._collect_leaves()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+            if x == self.root_node:
+                break
+            # only evict the host value of evicted nodes
+            if not x.evicted:
+                continue
+            assert x.lock_ref == 0 and x.host_value is not None
+
+            assert self.cache_controller.evict_host(x.host_value) > 0
+            for k, v in x.parent.children.items():
+                if v == x:
+                    break
+            del x.parent.children[k]
+
+            if len(x.parent.children) == 0 and x.parent.evicted:
+                heapq.heappush(leaves, x.parent)
+
+    def load_back(
+        self, node: TreeNode, mem_quota: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        # todo: more loading policies
+
+        last_hit_node = node
+        nodes_to_load = []
+        while node.evicted:
+            assert (
+                node.backuped
+            ), "No backup available on evicted nodes, should not happen"
+            nodes_to_load.insert(0, node)
+            node = node.parent
+        else:
+            ancester_node = node
+
+        # protect the ancestor nodes from eviction
+        delta = self.inc_lock_ref(ancester_node)
+
+        # load it all or not at all
+        host_indices = torch.cat([n.host_value for n in nodes_to_load])
+        if len(host_indices) < self.load_back_threshold or (
+            len(host_indices) > mem_quota + delta if mem_quota is not None else False
+        ):
+            # skip loading back if the total size is too small or exceeding the memory quota
+            self.dec_lock_ref(ancester_node)
+            return None
+
+        device_indices = self.cache_controller.load(
+            host_indices=host_indices, node_id=last_hit_node.id
+        )
+        if device_indices is None:
+            self.evict(len(host_indices))
+            device_indices = self.cache_controller.load(
+                host_indices=host_indices, node_id=last_hit_node.id
+            )
+        self.dec_lock_ref(ancester_node)
+        if device_indices is None:
+            # no sufficient GPU memory to load back KV caches
+            return None
+
+        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        offset = 0
+        for node in nodes_to_load:
+            node.value = device_indices[offset : offset + len(node.host_value)]
+            offset += len(node.host_value)
+            node.loading = True
+        self.evictable_size_ += len(device_indices)
+        self.inc_lock_ref(last_hit_node)
+
+        return device_indices
+
+    def loading_complete(self, node: TreeNode):
+        self.loading_check()
+        return node.loading == False
+
+    def init_load_back(
+        self,
+        last_node: TreeNode,
+        prefix_indices: torch.Tensor,
+        mem_quota: Optional[int] = None,
+    ):
+        assert (
+            len(prefix_indices) == 0 or prefix_indices.is_cuda
+        ), "indices of device kV caches should be on GPU"
+        if last_node.evicted:
+            loading_values = self.load_back(last_node, mem_quota)
+            if loading_values is not None:
+                prefix_indices = (
+                    loading_values
+                    if len(prefix_indices) == 0
+                    else torch.cat([prefix_indices, loading_values])
+                )
+                logger.debug(
+                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
+                )
+
+            while last_node.evicted:
+                last_node = last_node.parent
+
+        return last_node, prefix_indices
+
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
+                if not new_node.evicted:
+                    value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                self.inc_hit_count(child)
+                if not child.evicted:
+                    value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len: int):
+        # child node split into new_node -> child
+        new_node = TreeNode()
+        new_node.children = {key[split_len]: child}
+        new_node.parent = child.parent
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
+        new_node.loading = child.loading
+
+        # split value and host value if exists
+        if child.evicted:
+            new_node.value = None
+        else:
+            new_node.value = child.value[:split_len]
+            child.value = child.value[split_len:]
+        if child.host_value is not None:
+            new_node.host_value = child.host_value[:split_len]
+            child.host_value = child.host_value[split_len:]
+        child.parent = new_node
+        child.key = child.key[split_len:]
+        new_node.parent.children[key[0]] = new_node
+        return new_node
+
+    def _insert_helper(self, node: TreeNode, key: List, value):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+
+            if prefix_len == len(child.key):
+                if child.evicted:
+                    # change the reference if the node is evicted
+                    # this often happens in the case of KV cache recomputation
+                    child.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(child.host_value)
+                    self.evictable_size_ += len(value[:prefix_len])
+                    return self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+                else:
+                    self.inc_hit_count(child)
+                    return prefix_len + self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+
+            # partial match, split the node
+            new_node = self._split_node(child.key, child, prefix_len)
+            if new_node.evicted:
+                new_node.value = value[:prefix_len]
+                self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                self.evictable_size_ += len(new_node.value)
+                return self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+            else:
+                self.inc_hit_count(new_node)
+                return prefix_len + self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = value
+            node.children[key[0]] = new_node
+            self.evictable_size_ += len(value)
+
+            if self.cache_controller.write_policy == "write_through":
+                self.write_backup(new_node)
+        return 0
+
+    def _collect_leaves_device(self):
+        def is_leaf(node):
+            if node.evicted:
+                return False
+            if node == self.root_node:
+                return False
+            if len(node.children) == 0:
+                return True
+            for child in node.children.values():
+                if not child.evicted:
+                    return False
+            return True
+
+        ret_list = []
+        stack = [self.root_node]
+        while stack:
+            cur_node = stack.pop()
+            if is_leaf(cur_node):
+                ret_list.append(cur_node)
+            else:
+                for cur_child in cur_node.children.values():
+                    if not cur_child.evicted:
+                        stack.append(cur_child)
+        return ret_list
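
The new HiRadixCache adds a host-memory (CPU) backup tier behind the device radix tree; how a device node is evicted depends on the controller's write policy. A condensed sketch of the dispatch inside evict above (evict_one is a hypothetical helper written for illustration, not part of the file):

def evict_one(cache, node):
    # Nodes already backed up on host only need their device indices released.
    if node.host_value is not None:
        return cache._evict_write_through(node)
    policy = cache.cache_controller.write_policy
    if policy == "write_back":
        # Lazy policy: back the node up now; its device copy is freed once
        # the asynchronous write is acknowledged.
        return cache.write_backup(node)
    if policy == "write_through_selective":
        # Cold nodes (never promoted by inc_hit_count) are simply dropped.
        return cache._evict_write_through_selective(node)
    # Under "write_through" every inserted node was backed up eagerly,
    # so reaching this point indicates a bug.
    raise NotImplementedError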
--- a/sglang/srt/mem_cache/memory_pool.py
+++ b/sglang/srt/mem_cache/memory_pool.py
@@ -20,9 +20,11 @@ Memory pool.
 
 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-BaseTokenToKVPool maps a token location to its KV cache data.
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """
 
+import abc
 import logging
 import threading
 from enum import IntEnum
@@ -89,22 +91,43 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool:
-    """A memory pool that maps a token location to its kv cache data."""
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
+class TokenToKVPoolAllocator:
+    """An allocator managing the indices to kv cache data."""
 
     def __init__(
         self,
         size: int,
        dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
         self.device = device
 
         self.free_slots = None
@@ -112,9 +135,14 @@ class BaseTokenToKVPool:
         self.free_group = []
         self.clear()
 
+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)
 
+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
@@ -148,26 +176,8 @@ class BaseTokenToKVPool:
         self.is_in_free_group = False
         self.free_group = []
 
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
 
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
-class MHATokenToKVPool(BaseTokenToKVPool):
+class MHATokenToKVPool(KVCache):
 
     def __init__(
         self,
@@ -179,8 +189,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -192,7 +208,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
-            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
 
     def _create_buffers(self):
@@ -297,7 +313,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
 
 
-class MLATokenToKVPool(BaseTokenToKVPool):
+class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -308,8 +324,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank
 
         memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -356,7 +378,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         self.kv_buffer[layer_id][loc] = cache_k
 
 
-class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -368,8 +390,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         heavy_channel_num: int,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -437,7 +465,7 @@ def synchronized(func):
     return wrapper
 
 
-class MLATokenToKVPoolHost:
+class MHATokenToKVPoolHost:
 
     def __init__(
         self,
@@ -502,6 +530,9 @@ class MLATokenToKVPoolHost:
     def get_flat_data(self, indices):
        return self.kv_buffer[:, :, indices]
 
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
     @debug_timing
     def transfer(self, indices, flat_data):
         # backup prepared data from device to host
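
Taken together, the memory_pool.py hunks split the old BaseTokenToKVPool into two objects: KVCache implementations (MHATokenToKVPool, MLATokenToKVPool, DoubleSparseTokenToKVPool) own the physical buffers, while TokenToKVPoolAllocator owns the free list of token indices plus a reference to the cache it indexes. A self-contained toy sketch of that contract (the Toy* classes are hypothetical; only size, dtype, device, and kvcache appear in the allocator's constructor in this diff):

import torch

class ToyKVCache:
    """Hypothetical stand-in for a KVCache implementation (e.g. MHATokenToKVPool)."""

    def __init__(self, size: int, head_dim: int):
        # One slot of K and V storage per token index; real pools are per-layer.
        self.k_buffer = torch.zeros(size, head_dim)
        self.v_buffer = torch.zeros(size, head_dim)

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer, self.v_buffer


class ToyAllocator:
    """Mirrors the TokenToKVPoolAllocator contract: a free list of token slots."""

    def __init__(self, size: int, kvcache: ToyKVCache):
        self.free_slots = list(range(size))
        self._kvcache = kvcache

    def get_kvcache(self):
        return self._kvcache

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None  # caller is expected to evict and retry
        out = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return out

    def free(self, indices):
        self.free_slots.extend(indices)


cache = ToyKVCache(size=16, head_dim=4)
allocator = ToyAllocator(size=16, kvcache=cache)
slots = allocator.alloc(4)               # indices into the physical buffers
assert allocator.get_kvcache() is cache  # consumers reach the buffers through the allocator
allocator.free(slots)                    # e.g. what ChunkCache does when a request finishes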