sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/bench_serving.py +2 -2
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +95 -49
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +33 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +258 -782
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +7 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +112 -46
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/metrics/collector.py +8 -0
  95. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  96. sglang/srt/model_executor/forward_batch_info.py +12 -8
  97. sglang/srt/model_executor/model_runner.py +153 -134
  98. sglang/srt/model_loader/loader.py +2 -1
  99. sglang/srt/model_loader/weight_utils.py +1 -1
  100. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  101. sglang/srt/models/deepseek_nextn.py +23 -3
  102. sglang/srt/models/deepseek_v2.py +25 -19
  103. sglang/srt/models/minicpmv.py +28 -89
  104. sglang/srt/models/mllama.py +1 -1
  105. sglang/srt/models/qwen2.py +0 -1
  106. sglang/srt/models/qwen2_5_vl.py +25 -50
  107. sglang/srt/models/qwen2_vl.py +33 -49
  108. sglang/srt/openai_api/adapter.py +37 -15
  109. sglang/srt/openai_api/protocol.py +8 -1
  110. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  111. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  112. sglang/srt/server_args.py +19 -20
  113. sglang/srt/speculative/build_eagle_tree.py +6 -1
  114. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
  115. sglang/srt/speculative/eagle_utils.py +2 -1
  116. sglang/srt/speculative/eagle_worker.py +109 -38
  117. sglang/srt/utils.py +104 -9
  118. sglang/test/runners.py +104 -10
  119. sglang/test/test_block_fp8.py +106 -16
  120. sglang/test/test_custom_ops.py +88 -0
  121. sglang/test/test_utils.py +20 -4
  122. sglang/utils.py +0 -4
  123. sglang/version.py +1 -1
  124. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
  125. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
  126. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  127. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple
+from typing import Any, List, Tuple


 class BasePrefixCache(ABC):
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def evict(self, num_tokens: int, evict_callback: Callable):
+    def evict(self, num_tokens: int):
         pass

     @abstractmethod
-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         pass

     @abstractmethod
-    def dec_lock_ref(self, node):
+    def dec_lock_ref(self, node: Any):
         pass

-    @abstractmethod
     def evictable_size(self):
-        pass
+        return 0

-    @abstractmethod
     def protected_size(self):
-        raise NotImplementedError()
+        return 0

     def total_size(self):
         raise NotImplementedError()
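For reference, a minimal sketch (a toy class, not code from the wheel) of the updated BasePrefixCache call shapes: evict() no longer receives an eviction callback, the lock-ref hooks accept an arbitrary node object, and evictable_size()/protected_size() are now concrete defaults returning 0 instead of abstract methods.

# Toy stand-in that satisfies the 0.4.4 method shapes shown above.
class NoOpPrefixCache:
    def evict(self, num_tokens: int):
        pass  # nothing cached, nothing to evict

    def inc_lock_ref(self, node):
        return 0

    def dec_lock_ref(self, node):
        return 0

    def evictable_size(self):
        return 0

    def protected_size(self):
        return 0


cache = NoOpPrefixCache()
cache.evict(128)               # 0.4.4 call shape
# cache.evict(128, callback)   # 0.4.3.post3 shape, no longer part of the interface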
sglang/srt/mem_cache/chunk_cache.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple

 import torch

@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
-        self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.entries: Dict[str, ChunkCacheEntry] = {}
-
-        self.reset()

     def reset(self):
-        self.entries = {}
-
-    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
-        if rid not in self.entries:
-            return [], None
-
-        entry = self.entries[rid]
-        max_prefix_len = len(key)
-        return entry.value[:max_prefix_len], entry
+        pass

-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-        else:
-            token_id_len = len(token_ids)
+    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
+        return [], None

+    def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :token_id_len
+            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)

-        if req.rid in self.entries:
-            del self.entries[req.rid]
-
     def cache_unfinished_req(self, req: Req):
-        token_id_len = len(req.fill_ids)
-
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :token_id_len
+            req.req_pool_idx, : len(req.fill_ids)
         ]

-        if req.rid not in self.entries:
-            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
-
-        entry = self.entries[req.rid]
-        entry.value = kv_indices
+        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
-        req.last_node = entry

     def insert(self):
         raise NotImplementedError()

-    def evict(self, num_tokens: int, evict_callback: Callable):
+    def evict(self, num_tokens: int):
         pass

-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         return 0

-    def dec_lock_ref(self, node):
-        return 0
-
-    def evictable_size(self):
-        return 0
-
-    def pretty_print(self):
-        return ""
-
-    def protected_size(self):
+    def dec_lock_ref(self, node: Any):
         return 0

     def pretty_print(self):
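The rewritten ChunkCache keeps no per-request entries: match_prefix() always reports an empty prefix, and cache_unfinished_req() only stashes the request's KV indices on the request object. A self-contained sketch with toy stand-ins (ToyReqToTokenPool and ToyReq are illustrative names, not sglang classes):

import torch


class ToyReqToTokenPool:
    def __init__(self):
        # one row of fake KV indices per request slot
        self.req_to_token = torch.arange(32).reshape(2, 16)


class ToyReq:
    def __init__(self):
        self.req_pool_idx = 0
        self.fill_ids = list(range(8))  # tokens filled so far
        self.prefix_indices = None


def cache_unfinished_req(pool, req):
    # mirrors the new logic: slice this request's KV indices and record them
    # on the request for the prefill scheduler to reuse later
    kv_indices = pool.req_to_token[req.req_pool_idx, : len(req.fill_ids)]
    req.prefix_indices = kv_indices


pool, req = ToyReqToTokenPool(), ToyReq()
cache_unfinished_req(pool, req)
print(req.prefix_indices.tolist())  # the first 8 KV indices of request 0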
sglang/srt/mem_cache/hiradix_cache.py
@@ -1,5 +1,6 @@
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional

@@ -7,11 +8,12 @@ import torch

 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
-    MHATokenToKVPool,
     MHATokenToKVPoolHost,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)

@@ -21,11 +23,19 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: MHATokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
     ):
-        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+            token_to_kv_pool_allocator.get_kvcache()
+        )
+        self.tp_group = tp_cache_group
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-            token_to_kv_pool, self.token_to_kv_pool_host
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )

         # record the nodes with ongoing write through
@@ -35,7 +45,7 @@ class HiRadixCache(RadixCache):
         # todo: dynamically adjust the threshold
         self.write_through_threshold = 1
         self.load_back_threshold = 10
-        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
+        super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)

     def reset(self):
         TreeNode.counter = 0
@@ -53,14 +63,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
@@ -81,14 +89,20 @@ class HiRadixCache(RadixCache):
         node.hit_count = 0

     def writing_check(self):
-        while not self.cache_controller.ack_write_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_write_queue.get_nowait()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
-                # clear the reference
-                del self.ongoing_write_through[ack_id]
-            except Exception:
-                break
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchrnoize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
@@ -106,11 +120,9 @@ class HiRadixCache(RadixCache):
                 break

     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_

-    def evict(self, num_tokens: int, evict_callback=None):
+    def evict(self, num_tokens: int):
         leaves = self._collect_leaves_device()
         heapq.heapify(leaves)

@@ -160,7 +172,7 @@ class HiRadixCache(RadixCache):

     def _evict_write_through_selective(self, node: TreeNode):
         # evict a node not initiated write to host
-        self.cache_controller.mem_pool_device.free(node.value)
+        self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
         self._delete_leaf(node)
         return num_evicted
@@ -240,10 +252,6 @@ class HiRadixCache(RadixCache):

         return device_indices

-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -270,28 +278,49 @@ class HiRadixCache(RadixCache):

         return last_node, prefix_indices

-    def _match_prefix_helper(
-        self, node: TreeNode, key: List, value, last_node: TreeNode
-    ):
-        node.last_access_time = time.time()
-        if len(key) == 0:
-            return
+    def read_to_load_cache(self):
+        self.load_cache_event.set()

-        if key[0] in node.children.keys():
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
+    def _match_prefix_helper(self, node: TreeNode, key: List):
+        node.last_access_time = time.time()
+        value = []
+        while len(key) > 0 and key[0] in node.children.keys():
             child = node.children[key[0]]
+            child.last_access_time = time.time()
             prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
-                last_node[0] = new_node
+                node = new_node
+                break
             else:
                 self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
-                last_node[0] = child
-                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+                node = child
+                key = key[prefix_len:]
+        return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
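The new writing_check drains write acknowledgements in lockstep across tensor-parallel workers. A sketch of that idea in isolation (illustrative names; assumes torch.distributed has been initialized when more than one rank is involved): each rank all-reduces its ack-queue size with MIN, so every rank pops the same number of acks and the radix trees stay identical across TP workers.

import torch
import torch.distributed as dist


def drain_acks_in_lockstep(ack_queue, ongoing_write_through, dec_lock_ref, tp_group=None):
    # each rank may see a different number of completed writes
    queue_size = torch.tensor(ack_queue.qsize(), dtype=torch.int)
    if dist.is_initialized() and dist.get_world_size(group=tp_group) > 1:
        # keep only the count that every rank has already observed
        dist.all_reduce(queue_size, op=dist.ReduceOp.MIN, group=tp_group)
    for _ in range(queue_size.item()):
        ack_id = ack_queue.get()
        dec_lock_ref(ongoing_write_through[ack_id])
        del ongoing_write_through[ack_id]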
sglang/srt/mem_cache/memory_pool.py
@@ -20,9 +20,8 @@ Memory pool.

 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-TokenToKVPoolAllocator maps a token location to its KV cache data.
-KVCache actually holds the physical kv cache. Allocation indices are allocated
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """

 import abc
@@ -92,42 +91,73 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """A memory pool that maps a token location to its kv cache data."""
+    """An allocator managing the indices to kv cache data."""

     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
         self.device = device
+        self.page_size = 1

         self.free_slots = None
         self.is_not_in_free_group = True
         self.free_group = []
         self.clear()

+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)

+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None

         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-
-        return select_index.to(self.device, non_blocking=True)
+        return select_index

     def free(self, free_index: torch.Tensor):
         if free_index.numel() == 0:
             return

         if self.is_not_in_free_group:
-            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+            self.free_slots = torch.concat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)

@@ -142,41 +172,19 @@ class TokenToKVPoolAllocator:

     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.free_slots = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
         self.is_in_free_group = False
         self.free_group = []


-class KVCache(abc.ABC):
-
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):

     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
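These hunks split responsibilities: KVCache owns the physical buffers, while TokenToKVPoolAllocator only hands out integer slot indices (now an int64 tensor kept on the target device) and exposes the cache through get_kvcache(). A condensed sketch of that ownership split with toy classes (illustrative names, not the sglang implementations):

import torch


class ToyKVCache:
    def __init__(self, size, head_num=2, head_dim=4):
        # slot 0 is reserved as the padded slot, hence size + 1 rows
        self.k_buffer = torch.zeros(size + 1, head_num, head_dim)
        self.v_buffer = torch.zeros(size + 1, head_num, head_dim)


class ToyAllocator:
    def __init__(self, size, kvcache, device="cpu"):
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int64, device=device)
        self._kvcache = kvcache

    def get_kvcache(self):
        return self._kvcache

    def alloc(self, need_size):
        if need_size > len(self.free_slots):
            return None
        out = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return out

    def free(self, indices):
        self.free_slots = torch.concat((self.free_slots, indices))


kv = ToyKVCache(size=16)
alloc = ToyAllocator(size=16, kvcache=kv)
loc = alloc.alloc(4)          # indices into the KV buffers
kv.k_buffer[loc] = 1.0        # the cache, not the allocator, holds the data
alloc.free(loc)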
@@ -185,6 +193,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -201,6 +210,10 @@ class MHATokenToKVPool(KVCache):
         self.layer_num = layer_num
         self._create_buffers()

+        self.layer_transfer_counter = None
+        self.capture_mode = False
+        self.alt_stream = torch.cuda.Stream()
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -211,16 +224,16 @@ class MHATokenToKVPool(KVCache):
             # [size, head_num, head_dim] for each layer
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.k_buffer = [
-                torch.empty(
-                    (self.size + 1, self.head_num, self.head_dim),
+                torch.zeros(
+                    (self.size + self.page_size, self.head_num, self.head_dim),
                     dtype=self.store_dtype,
                     device=self.device,
                 )
                 for _ in range(self.layer_num)
             ]
             self.v_buffer = [
-                torch.empty(
-                    (self.size + 1, self.head_num, self.head_dim),
+                torch.zeros(
+                    (self.size + self.page_size, self.head_num, self.head_dim),
                     dtype=self.store_dtype,
                     device=self.device,
                 )
@@ -262,12 +275,28 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
@@ -292,14 +321,44 @@ class MHATokenToKVPool(KVCache):
                 cache_v.div_(v_scale)
             cache_k = cache_k.to(self.dtype)
             cache_v = cache_v.to(self.dtype)
+
         if self.store_dtype != self.dtype:
-            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
-            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        if self.capture_mode:
+            self.alt_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self.alt_stream):
+                self.k_buffer[layer_id][loc] = cache_k
+                self.v_buffer[layer_id][loc] = cache_v
+            torch.cuda.current_stream().wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v


+@torch.compile
+def fused_downcast(
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dtype: torch.dtype,
+    store_dtype: torch.dtype,
+    max_fp8: float,
+    min_fp8: float,
+):
+    cache_k = cache_k / k_scale
+    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
+    cache_v = cache_v / v_scale
+    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
+    cache_k = cache_k.to(dtype)
+    cache_v = cache_v.to(dtype)
+    cache_k = cache_k.view(store_dtype)
+    cache_v = cache_v.view(store_dtype)
+    return cache_k, cache_v
+
+
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
 @torch.compile(dynamic=True, backend=get_compiler_backend())
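set_kv_buffer now issues the KV write on an alternate CUDA stream while capture_mode is set. A standalone sketch of that stream handshake (requires a CUDA device; buffer shapes are arbitrary): the side stream waits on the current stream, performs the scatter, and the current stream then waits on the side stream before continuing.

import torch

if torch.cuda.is_available():
    main = torch.cuda.current_stream()
    alt = torch.cuda.Stream()

    k_buffer = torch.zeros(16, 2, 4, device="cuda")
    loc = torch.tensor([1, 2, 3], device="cuda")
    cache_k = torch.ones(3, 2, 4, device="cuda")

    alt.wait_stream(main)              # side stream starts after pending main work
    with torch.cuda.stream(alt):
        k_buffer[loc] = cache_k        # scatter new K entries on the side stream
    main.wait_stream(alt)              # main stream resumes after the write completes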
@@ -312,6 +371,7 @@ class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
@@ -336,8 +396,8 @@ class MLATokenToKVPool(KVCache):
         with memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
-                torch.empty(
-                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                torch.zeros(
+                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                     dtype=self.store_dtype,
                     device=device,
                 )
@@ -377,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -386,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -400,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
         with memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
-                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]
             self.v_buffer = [
-                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]

             # [size, head_num, heavy_channel_num] for each layer
             self.label_buffer = [
-                torch.empty(
+                torch.zeros(
                     (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                 )
                 for _ in range(layer_num)
@@ -465,7 +531,7 @@ class MHATokenToKVPoolHost:
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 2.0,
+        host_to_device_ratio: float = 3.0,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
@@ -505,7 +571,7 @@ class MHATokenToKVPoolHost:
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )

-        self.kv_buffer = torch.empty(
+        self.kv_buffer = torch.zeros(
             (2, self.layer_num, self.size, self.head_num, self.head_dim),
             dtype=self.dtype,
             device=self.device,
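The remaining hunks add a page_size argument to every pool and switch the buffer allocations from torch.empty to torch.zeros. A small illustration of the sizing (dimensions are arbitrary): the buffers now reserve page_size padding rows instead of a single slot, and untouched slots read back as zeros rather than uninitialized memory.

import torch

size, page_size, head_num, head_dim = 8, 2, 2, 4
k_buffer = torch.zeros(size + page_size, head_num, head_dim)

assert k_buffer.shape[0] == size + page_size   # padding scales with the page size
assert torch.all(k_buffer[0] == 0)             # the padded slot is well-defined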