sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/chunk_cache.py

@@ -1,7 +1,8 @@
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""
-
+
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple

 import torch

@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
-        self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.entries: Dict[str, ChunkCacheEntry] = {}
-
-        self.reset()

     def reset(self):
-
-
-    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
-        if rid not in self.entries:
-            return [], None
-
-        entry = self.entries[rid]
-        max_prefix_len = len(key)
-        return entry.value[:max_prefix_len], entry
+        pass

-    def
-
-            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-        else:
-            token_id_len = len(token_ids)
+    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
+        return [], None

+    def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)

-        if req.rid in self.entries:
-            del self.entries[req.rid]
-
     def cache_unfinished_req(self, req: Req):
-        token_id_len = len(req.fill_ids)
-
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, : len(req.fill_ids)
         ]

-
-            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
-
-        entry = self.entries[req.rid]
-        entry.value = kv_indices
+        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
-        req.last_node = entry

     def insert(self):
         raise NotImplementedError()

-    def evict(self, num_tokens: int
+    def evict(self, num_tokens: int):
         pass

-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         return 0

-    def dec_lock_ref(self, node):
-        return 0
-
-    def evictable_size(self):
-        return 0
-
-    def pretty_print(self):
-        return ""
-
-    def protected_size(self):
+    def dec_lock_ref(self, node: Any):
         return 0

     def pretty_print(self):
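The rewrite above strips ChunkCache of its per-request entry table: match_prefix now always reports a miss, and chunked prefill carries its KV indices on the request itself through req.prefix_indices. A minimal sketch of that stateless contract; NoOpPrefixCache is a hypothetical stand-in for illustration, not a class from the package:

from typing import Any, List, Tuple

class NoOpPrefixCache:
    """Always-miss prefix cache mirroring the simplified ChunkCache API."""

    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], Any]:
        # Report an empty prefix; callers keep their own indices.
        return [], None

    def evict(self, num_tokens: int):
        pass  # nothing is cached, so nothing can be evicted

    def inc_lock_ref(self, node: Any):
        return 0

    def dec_lock_ref(self, node: Any):
        return 0

cache = NoOpPrefixCache()
prefix, node = cache.match_prefix(key=[1, 2, 3])
assert prefix == [] and node is None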
sglang/srt/mem_cache/hiradix_cache.py

@@ -1,5 +1,6 @@
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional

@@ -7,11 +8,12 @@ import torch

 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
-    MHATokenToKVPool,
     MHATokenToKVPoolHost,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)

@@ -21,11 +23,25 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
+        page_size: int,
     ):
-
+        if page_size != 1:
+            raise ValueError(
+                "Page size larger than 1 is not yet supported in HiRadixCache."
+            )
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+            token_to_kv_pool_allocator.get_kvcache()
+        )
+        self.tp_group = tp_cache_group
+        self.page_size = page_size
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )

         # record the nodes with ongoing write through
@@ -35,7 +51,9 @@ class HiRadixCache(RadixCache):
         # todo: dynamically adjust the threshold
         self.write_through_threshold = 1
         self.load_back_threshold = 10
-        super().__init__(
+        super().__init__(
+            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+        )

     def reset(self):
         TreeNode.counter = 0
@@ -53,14 +71,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
@@ -81,14 +97,20 @@ class HiRadixCache(RadixCache):
             node.hit_count = 0

     def writing_check(self):
-
-
-
-
-
-
-
-
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchrnoize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
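The rebuilt writing_check drains write-through acknowledgements in lockstep across tensor-parallel ranks: each rank contributes its local queue depth, and the ReduceOp.MIN all-reduce picks the count every rank has reached, so all ranks release the same locks and keep identical radix trees. A toy sketch of that agreement rule, with the collective mocked by a plain min (no process group involved):

# per_rank_qsizes stands in for each rank's ack_write_queue.qsize();
# all_reduce(MIN) would make `agreed` identical on every rank.
per_rank_qsizes = [3, 5, 4]
agreed = min(per_rank_qsizes)
for qsize in per_rank_qsizes:
    # no rank is asked to pop an ack it has not yet received
    assert agreed <= qsize
print(f"each rank drains {agreed} acks this round")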
@@ -106,11 +128,9 @@ class HiRadixCache(RadixCache):
                 break

     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_

-    def evict(self, num_tokens: int
+    def evict(self, num_tokens: int):
         leaves = self._collect_leaves_device()
         heapq.heapify(leaves)

@@ -160,7 +180,7 @@ class HiRadixCache(RadixCache):

     def _evict_write_through_selective(self, node: TreeNode):
         # evict a node not initiated write to host
-        self.cache_controller.
+        self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
         self._delete_leaf(node)
         return num_evicted
@@ -240,10 +260,6 @@ class HiRadixCache(RadixCache):

         return device_indices

-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -270,28 +286,49 @@ class HiRadixCache(RadixCache):

         return last_node, prefix_indices

-    def
-        self
-    ):
-        node.last_access_time = time.time()
-        if len(key) == 0:
-            return
+    def read_to_load_cache(self):
+        self.load_cache_event.set()

-
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
+    def _match_prefix_helper(self, node: TreeNode, key: List):
+        node.last_access_time = time.time()
+        value = []
+        while len(key) > 0 and key[0] in node.children.keys():
             child = node.children[key[0]]
+            child.last_access_time = time.time()
             prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
-
+                node = new_node
+                break
             else:
                 self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
-
-
+                node = child
+                key = key[prefix_len:]
+        return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
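match_prefix in HiRadixCache now separates the deepest matched node, which may have been evicted to host memory, from its nearest device-resident ancestor, and include_evicted=True returns both. A self-contained illustration of the parent walk; Node is a hypothetical stand-in for TreeNode:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    evicted: bool = False
    parent: Optional["Node"] = None

root = Node()
on_device = Node(parent=root)
on_host = Node(evicted=True, parent=on_device)  # backed up, off device

last_node_global = on_host           # deepest match, possibly evicted
last_node = last_node_global
while last_node.evicted:             # climb to a device-resident ancestor
    last_node = last_node.parent

assert last_node is on_device and last_node_global is on_host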
sglang/srt/mem_cache/memory_pool.py

@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
         self.size = size
         self.dtype = dtype
         self.device = device
+        self.page_size = 1

         self.free_slots = None
         self.is_not_in_free_group = True
@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator:

         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-
-        return select_index.to(self.device, non_blocking=True)
+        return select_index

     def free(self, free_index: torch.Tensor):
         if free_index.numel() == 0:
             return

         if self.is_not_in_free_group:
-            self.free_slots = torch.concat((self.free_slots, free_index
+            self.free_slots = torch.concat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)

@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator:

     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
+        self.free_slots = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
         self.is_in_free_group = False
         self.free_group = []
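This allocator now fixes page_size = 1 (paged allocation moved into the new sglang/srt/mem_cache/paged_allocator.py), keeps free_slots on the target device instead of copying on every allocation, and still reserves slot 0 for padded tokens, so the free list covers 1 through size. A small sketch of that bookkeeping on a CPU tensor:

import torch

# Slot 0 is reserved as a dummy write target, so free slots are 1..size.
size = 8
free_slots = torch.arange(1, size + 1, dtype=torch.int64)

need = 3
allocated, free_slots = free_slots[:need], free_slots[need:]
# freeing concatenates the indices back onto the free list
free_slots = torch.concat((free_slots, allocated))
assert free_slots.numel() == size and 0 not in free_slots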
@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -206,6 +210,10 @@ class MHATokenToKVPool(KVCache):
         self.layer_num = layer_num
         self._create_buffers()

+        self.layer_transfer_counter = None
+        self.capture_mode = False
+        self.alt_stream = torch.cuda.Stream()
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -216,16 +224,16 @@ class MHATokenToKVPool(KVCache):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
-            torch.
-                (self.size +
+            torch.zeros(
+                (self.size + self.page_size, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
                 device=self.device,
             )
             for _ in range(self.layer_num)
         ]
         self.v_buffer = [
-            torch.
-                (self.size +
+            torch.zeros(
+                (self.size + self.page_size, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
                 device=self.device,
             )
@@ -267,12 +275,28 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
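The layer_transfer_counter hook above lets a host-to-device loader stream KV layers in order while attention consumes them: get_key_buffer(layer_id) blocks until that layer has landed. A hedged sketch of the handshake this implies; TransferCounter is a hypothetical stand-in built on a condition variable, not the class the cache controller actually registers:

import threading

class TransferCounter:
    """Tracks the highest layer already copied to the device."""

    def __init__(self):
        self._done = -1
        self._cv = threading.Condition()

    def advance(self, layer_id: int):
        with self._cv:
            self._done = layer_id
            self._cv.notify_all()

    def wait_until(self, layer_id: int):
        with self._cv:
            self._cv.wait_for(lambda: self._done >= layer_id)

counter = TransferCounter()
loader = threading.Thread(
    target=lambda: [counter.advance(i) for i in range(4)]  # layers 0..3
)
loader.start()
counter.wait_until(3)  # attention on layer 3 proceeds once it has landed
loader.join()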
@@ -297,14 +321,44 @@ class MHATokenToKVPool(KVCache):
             cache_v.div_(v_scale)
             cache_k = cache_k.to(self.dtype)
             cache_v = cache_v.to(self.dtype)
+
         if self.store_dtype != self.dtype:
-
-
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        if self.capture_mode and cache_k.shape[0] < 4:
+            self.alt_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self.alt_stream):
+                self.k_buffer[layer_id][loc] = cache_k
+                self.v_buffer[layer_id][loc] = cache_v
+            torch.cuda.current_stream().wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v


+@torch.compile
+def fused_downcast(
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dtype: torch.dtype,
+    store_dtype: torch.dtype,
+    max_fp8: float,
+    min_fp8: float,
+):
+    cache_k = cache_k / k_scale
+    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
+    cache_v = cache_v / v_scale
+    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
+    cache_k = cache_k.to(dtype)
+    cache_v = cache_v.to(dtype)
+    cache_k = cache_k.view(store_dtype)
+    cache_v = cache_v.view(store_dtype)
+    return cache_k, cache_v
+
+
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
 @torch.compile(dynamic=True, backend=get_compiler_backend())
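Under capture mode, set_kv_buffer routes very small writes (fewer than 4 tokens) through a side stream bracketed by wait_stream calls in both directions, a fork-join that moves the scatter off the main stream without reordering it. The pattern in isolation, assuming a CUDA device is available:

import torch

if torch.cuda.is_available():
    main = torch.cuda.current_stream()
    alt = torch.cuda.Stream()
    buf = torch.zeros(16, device="cuda")
    src = torch.ones(2, device="cuda")
    loc = torch.tensor([3, 7], device="cuda")

    alt.wait_stream(main)      # alt sees all prior work on main
    with torch.cuda.stream(alt):
        buf[loc] = src         # the scatter runs on the side stream
    main.wait_stream(alt)      # main waits before touching buf again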
@@ -317,6 +371,7 @@ class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
@@ -341,8 +396,8 @@ class MLATokenToKVPool(KVCache):
         with memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
-                torch.
-                    (size +
+                torch.zeros(
+                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                     dtype=self.store_dtype,
                     device=device,
                 )
@@ -382,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -391,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -405,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
         with memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
-                torch.
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]
             self.v_buffer = [
-                torch.
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]

             # [size, head_num, heavy_channel_num] for each layer
             self.label_buffer = [
-                torch.
+                torch.zeros(
                     (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                 )
                 for _ in range(layer_num)
@@ -470,7 +531,7 @@ class MHATokenToKVPoolHost:
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float =
+        host_to_device_ratio: float = 3.0,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
@@ -510,7 +571,7 @@ class MHATokenToKVPoolHost:
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )

-        self.kv_buffer = torch.
+        self.kv_buffer = torch.zeros(
             (2, self.layer_num, self.size, self.head_num, self.head_dim),
             dtype=self.dtype,
             device=self.device,
@@ -530,6 +591,9 @@ class MHATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
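get_flat_data_by_layer pairs with the per-layer device transfer earlier in the diff: the host pool is laid out as [2 (K/V), layer, token, head, dim], so slicing one layer for a batch of token indices produces exactly the tensor transfer_per_layer consumes. A shape check with toy dimensions:

import torch

layers, tokens, heads, dim = 4, 32, 2, 8
kv_buffer = torch.zeros(2, layers, tokens, heads, dim)

indices = torch.tensor([1, 5, 9])
flat = kv_buffer[:, 2, indices]   # K and V for layer 2 only
assert flat.shape == (2, 3, heads, dim)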