sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/chunk_cache.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple
 
 import torch
 
@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
-        self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.entries: Dict[str, ChunkCacheEntry] = {}
-
-        self.reset()
 
     def reset(self):
-        self.entries = {}
-
-    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
-        if rid not in self.entries:
-            return [], None
-
-        entry = self.entries[rid]
-        max_prefix_len = len(key)
-        return entry.value[:max_prefix_len], entry
+        pass
 
-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-        else:
-            token_id_len = len(token_ids)
+    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
+        return [], None
 
+    def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :token_id_len
+            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
 
-        if req.rid in self.entries:
-            del self.entries[req.rid]
-
     def cache_unfinished_req(self, req: Req):
-        token_id_len = len(req.fill_ids)
-
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :token_id_len
+            req.req_pool_idx, : len(req.fill_ids)
         ]
 
-        if req.rid not in self.entries:
-            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
-
-        entry = self.entries[req.rid]
-        entry.value = kv_indices
+        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
-        req.last_node = entry
 
     def insert(self):
         raise NotImplementedError()
 
-    def evict(self, num_tokens: int, evict_callback: Callable):
+    def evict(self, num_tokens: int):
         pass
 
-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         return 0
 
-    def dec_lock_ref(self, node):
-        return 0
-
-    def evictable_size(self):
-        return 0
-
-    def pretty_print(self):
-        return ""
-
-    def protected_size(self):
+    def dec_lock_ref(self, node: Any):
         return 0
 
     def pretty_print(self):
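The rewritten ChunkCache is now a stateless fallback used when the radix cache is disabled: match_prefix always reports a miss, and cache_finished_req derives the token count from the request itself. Below is a minimal usage sketch, assuming the constructor takes exactly the two pools shown above; FakeReqToTokenPool, FakeKVAllocator, and SimpleReq are hypothetical stand-ins that model only the calls ChunkCache makes, and the driver code is illustrative, not sglang's scheduler.

# Illustrative only: requires the sglang wheel for ChunkCache; everything else is a stub.
import torch
from sglang.srt.mem_cache.chunk_cache import ChunkCache

class FakeReqToTokenPool:
    def __init__(self, max_reqs, max_tokens):
        self.req_to_token = torch.zeros(max_reqs, max_tokens, dtype=torch.int64)
    def free(self, req_pool_idx):
        pass  # slot bookkeeping omitted in this sketch

class FakeKVAllocator:
    def free(self, indices):
        pass  # slot bookkeeping omitted in this sketch

class SimpleReq:
    def __init__(self, rid, input_ids):
        self.rid, self.req_pool_idx = rid, 0
        self.origin_input_ids = input_ids
        self.fill_ids = input_ids      # prompt tokens filled so far
        self.output_ids = []
        self.prefix_indices = None

cache = ChunkCache(FakeReqToTokenPool(4, 64), FakeKVAllocator())
req = SimpleReq("r0", [1, 2, 3, 4])

prefix, node = cache.match_prefix(rid=req.rid, key=req.fill_ids)
assert prefix == [] and node is None        # always a miss: no radix tree is kept

cache.cache_unfinished_req(req)             # records kv indices on req.prefix_indices
req.output_ids = [5, 6]
cache.cache_finished_req(req)               # frees the req slot and its kv indices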
sglang/srt/mem_cache/hiradix_cache.py
@@ -1,5 +1,6 @@
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional
 
@@ -7,11 +8,12 @@ import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
-    MHATokenToKVPool,
     MHATokenToKVPoolHost,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
 
 logger = logging.getLogger(__name__)
 
@@ -21,11 +23,25 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: MHATokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
+        page_size: int,
     ):
-        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
+        if page_size != 1:
+            raise ValueError(
+                "Page size larger than 1 is not yet supported in HiRadixCache."
+            )
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+            token_to_kv_pool_allocator.get_kvcache()
+        )
+        self.tp_group = tp_cache_group
+        self.page_size = page_size
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-            token_to_kv_pool, self.token_to_kv_pool_host
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
        )
 
         # record the nodes with ongoing write through
@@ -35,7 +51,9 @@ class HiRadixCache(RadixCache):
         # todo: dynamically adjust the threshold
         self.write_through_threshold = 1
         self.load_back_threshold = 10
-        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
+        super().__init__(
+            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+        )
 
     def reset(self):
         TreeNode.counter = 0
@@ -53,14 +71,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
@@ -81,14 +97,20 @@ class HiRadixCache(RadixCache):
         node.hit_count = 0
 
     def writing_check(self):
-        while not self.cache_controller.ack_write_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_write_queue.get_nowait()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
-                # clear the reference
-                del self.ongoing_write_through[ack_id]
-            except Exception:
-                break
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchrnoize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]
 
     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
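The new writing_check no longer drains whatever happens to be in the ack queue; it first agrees with the other tensor-parallel workers on a common number of completed write-backs (ReduceOp.MIN), so every rank applies identical updates to its radix tree. Below is a standalone sketch of just that synchronization pattern (gloo backend, two CPU processes); it does not touch sglang's HiCacheController, and the worker setup is purely illustrative.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def synced_drain_count(local_queue_size: int, group=None) -> int:
    queue_size = torch.tensor(local_queue_size, dtype=torch.int)
    if dist.get_world_size(group=group) > 1:
        # MIN: only drain entries that every rank has already acknowledged
        dist.all_reduce(queue_size, op=dist.ReduceOp.MIN, group=group)
    return int(queue_size.item())

def _worker(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29511", rank=rank, world_size=world_size
    )
    local = 3 if rank == 0 else 5   # ranks finished different numbers of write-backs
    print(f"rank {rank}: draining {synced_drain_count(local)} entries")  # both print 3
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)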
@@ -106,11 +128,9 @@ class HiRadixCache(RadixCache):
                 break
 
     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_
 
-    def evict(self, num_tokens: int, evict_callback=None):
+    def evict(self, num_tokens: int):
         leaves = self._collect_leaves_device()
         heapq.heapify(leaves)
 
@@ -160,7 +180,7 @@ class HiRadixCache(RadixCache):
 
     def _evict_write_through_selective(self, node: TreeNode):
         # evict a node not initiated write to host
-        self.cache_controller.mem_pool_device.free(node.value)
+        self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
         self._delete_leaf(node)
         return num_evicted
@@ -240,10 +260,6 @@ class HiRadixCache(RadixCache):
 
         return device_indices
 
-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -270,28 +286,49 @@ class HiRadixCache(RadixCache):
 
         return last_node, prefix_indices
 
-    def _match_prefix_helper(
-        self, node: TreeNode, key: List, value, last_node: TreeNode
-    ):
-        node.last_access_time = time.time()
-        if len(key) == 0:
-            return
+    def read_to_load_cache(self):
+        self.load_cache_event.set()
 
-        if key[0] in node.children.keys():
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
+    def _match_prefix_helper(self, node: TreeNode, key: List):
+        node.last_access_time = time.time()
+        value = []
+        while len(key) > 0 and key[0] in node.children.keys():
             child = node.children[key[0]]
+            child.last_access_time = time.time()
             prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
-                last_node[0] = new_node
+                node = new_node
+                break
             else:
                 self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
-                last_node[0] = child
-                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+                node = child
+                key = key[prefix_len:]
+        return value, node
 
     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
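match_prefix now walks the tree iteratively and, with include_evicted=True, also reports the deepest matching node even when its KV entries currently live only in host memory. A hedged usage sketch follows; the surrounding control flow is hypothetical, and only the method names and return shapes come from the hunk above.

def probe_prefix(cache, token_ids):
    # Default form: device-resident KV indices plus the deepest non-evicted node.
    prefix, last_device_node = cache.match_prefix(key=token_ids)

    # Extended form: additionally returns the deepest node overall, which may be
    # evicted; comparing the two tells the caller whether a host-to-device
    # load-back (init_load_back + read_to_load_cache) could extend the prefix.
    prefix, last_device_node, last_global_node = cache.match_prefix(
        key=token_ids, include_evicted=True
    )
    needs_load_back = last_global_node is not last_device_node
    return prefix.numel(), needs_load_back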
sglang/srt/mem_cache/memory_pool.py
@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
         self.size = size
         self.dtype = dtype
         self.device = device
+        self.page_size = 1
 
         self.free_slots = None
         self.is_not_in_free_group = True
@@ -149,15 +150,14 @@
 
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-
-        return select_index.to(self.device, non_blocking=True)
+        return select_index
 
     def free(self, free_index: torch.Tensor):
         if free_index.numel() == 0:
             return
 
         if self.is_not_in_free_group:
-            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+            self.free_slots = torch.concat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)
 
@@ -172,7 +172,9 @@
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.free_slots = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
         self.is_in_free_group = False
         self.free_group = []
 
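Across these three TokenToKVPoolAllocator hunks the free-slot list becomes an int64 tensor that stays on the target device, so alloc and free no longer bounce through the CPU. The following is a stripped-down, self-contained re-implementation of that slot-allocator pattern for illustration only; it is not sglang's class.

import torch

class SlotAllocator:
    """Minimal sketch of the device-resident free-slot allocator pattern."""

    def __init__(self, size: int, device: str = "cpu"):
        self.size = size
        self.device = device
        self.clear()

    def clear(self):
        # Slot 0 is reserved as a dummy slot, so usable slots are 1..size.
        self.free_slots = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )

    def alloc(self, need_size: int):
        if need_size > self.free_slots.numel():
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index  # stays on self.device, no host round trip

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return
        self.free_slots = torch.concat((self.free_slots, free_index))

alloc = SlotAllocator(size=8)
a = alloc.alloc(3)          # tensor([1, 2, 3])
alloc.free(a)
print(alloc.alloc(5))       # tensor([4, 5, 6, 7, 8])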
@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -206,6 +210,10 @@
         self.layer_num = layer_num
         self._create_buffers()
 
+        self.layer_transfer_counter = None
+        self.capture_mode = False
+        self.alt_stream = torch.cuda.Stream()
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -216,16 +224,16 @@
             # [size, head_num, head_dim] for each layer
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.k_buffer = [
-                torch.empty(
-                    (self.size + 1, self.head_num, self.head_dim),
+                torch.zeros(
+                    (self.size + self.page_size, self.head_num, self.head_dim),
                     dtype=self.store_dtype,
                     device=self.device,
                 )
                 for _ in range(self.layer_num)
             ]
             self.v_buffer = [
-                torch.empty(
-                    (self.size + 1, self.head_num, self.head_dim),
+                torch.zeros(
+                    (self.size + self.page_size, self.head_num, self.head_dim),
                     dtype=self.store_dtype,
                     device=self.device,
                 )
@@ -267,12 +275,28 @@
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]
 
+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
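The layer_transfer_counter hook lets attention layers block until the host-to-device copy for their layer has landed: the loading path calls transfer_per_layer per layer, while get_key_buffer/get_value_buffer wait on wait_until(layer_id) before reading. The object sglang actually registers is not shown in this diff; the toy class below is a hypothetical stand-in that only illustrates the wait_until contract.

import threading

class ToyLayerTransferCounter:
    """Hypothetical stand-in for the object passed to
    register_layer_transfer_counter(); only wait_until() matters here."""

    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        self._done_upto = -1                 # highest layer fully transferred
        self._cond = threading.Condition()

    def mark_layer_done(self, layer_id: int):
        # called by the loading thread after the copy for layer_id finished
        with self._cond:
            self._done_upto = max(self._done_upto, layer_id)
            self._cond.notify_all()

    def wait_until(self, layer_id: int):
        # called from get_key_buffer()/get_value_buffer() before reading layer_id
        with self._cond:
            self._cond.wait_for(lambda: self._done_upto >= layer_id)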
@@ -297,14 +321,44 @@ class MHATokenToKVPool(KVCache):
                 cache_v.div_(v_scale)
             cache_k = cache_k.to(self.dtype)
             cache_v = cache_v.to(self.dtype)
+
         if self.store_dtype != self.dtype:
-            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
-            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        if self.capture_mode and cache_k.shape[0] < 4:
+            self.alt_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self.alt_stream):
+                self.k_buffer[layer_id][loc] = cache_k
+                self.v_buffer[layer_id][loc] = cache_v
+            torch.cuda.current_stream().wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v
 
 
+@torch.compile
+def fused_downcast(
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dtype: torch.dtype,
+    store_dtype: torch.dtype,
+    max_fp8: float,
+    min_fp8: float,
+):
+    cache_k = cache_k / k_scale
+    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
+    cache_v = cache_v / v_scale
+    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
+    cache_k = cache_k.to(dtype)
+    cache_v = cache_v.to(dtype)
+    cache_k = cache_k.view(store_dtype)
+    cache_v = cache_v.view(store_dtype)
+    return cache_k, cache_v
+
+
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
 @torch.compile(dynamic=True, backend=get_compiler_backend())
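In capture mode, very small KV writes (fewer than 4 tokens) are issued on a pre-created side stream and fenced with wait_stream in both directions, so the scatter can overlap with other work on the main stream. Below is a generic, self-contained sketch of that fencing pattern under the stated assumptions; it is not tied to sglang's buffers and needs a CUDA device to run.

import torch

def copy_on_side_stream(dst, src, loc, alt_stream: torch.cuda.Stream):
    # Fence 1: the side stream must see everything already queued on the
    # current stream (e.g. the kernel that produced `src`).
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        dst[loc] = src          # small scatter runs on the side stream
    # Fence 2: the current stream may not read or reuse `dst`/`src`
    # until the side-stream copy has finished.
    torch.cuda.current_stream().wait_stream(alt_stream)

if torch.cuda.is_available():
    buf = torch.zeros(16, 8, device="cuda")
    new_rows = torch.randn(2, 8, device="cuda")
    side = torch.cuda.Stream()
    copy_on_side_stream(buf, new_rows, torch.tensor([3, 7], device="cuda"), side)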
@@ -317,6 +371,7 @@ class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
@@ -341,8 +396,8 @@
         with memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
-                torch.empty(
-                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                torch.zeros(
+                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                     dtype=self.store_dtype,
                     device=device,
                 )
@@ -382,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -391,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -405,17 +462,21 @@
         with memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
-                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]
             self.v_buffer = [
-                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]
 
             # [size, head_num, heavy_channel_num] for each layer
             self.label_buffer = [
-                torch.empty(
+                torch.zeros(
                     (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                 )
                 for _ in range(layer_num)
@@ -470,7 +531,7 @@ class MHATokenToKVPoolHost:
     def __init__(
         self,
        device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 2.0,
+        host_to_device_ratio: float = 3.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
@@ -510,7 +571,7 @@ class MHATokenToKVPoolHost:
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )
 
-        self.kv_buffer = torch.empty(
+        self.kv_buffer = torch.zeros(
             (2, self.layer_num, self.size, self.head_num, self.head_dim),
             dtype=self.dtype,
             device=self.device,
@@ -530,6 +591,9 @@
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]
 
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
 
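The host-side pool keeps K and V for all layers in one tensor of shape (2, layer_num, size, head_num, head_dim), with index 0/1 on the first axis selecting K or V. get_flat_data slices all layers at once, while the new get_flat_data_by_layer yields the per-layer slab that a layer-by-layer transfer consumes. A small shape-only illustration, using made-up dimensions:

import torch

layer_num, size, head_num, head_dim = 4, 32, 2, 8
kv_buffer = torch.zeros(2, layer_num, size, head_num, head_dim)  # [K/V, layer, token, head, dim]

indices = torch.tensor([3, 7, 9])          # host slots holding one node's tokens

whole = kv_buffer[:, :, indices]           # get_flat_data: all layers at once
per_layer = kv_buffer[:, 1, indices]       # get_flat_data_by_layer(indices, layer_id=1)

print(whole.shape)      # torch.Size([2, 4, 3, 2, 8])
print(per_layer.shape)  # torch.Size([2, 3, 2, 8])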