sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff compares publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- sglang/bench_serving.py +2 -2
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +95 -49
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +33 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +258 -782
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +7 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +112 -46
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +153 -134
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -20
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +109 -38
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import
+from typing import Any, List, Tuple


 class BasePrefixCache(ABC):
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def evict(self, num_tokens: int
+    def evict(self, num_tokens: int):
         pass

     @abstractmethod
-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         pass

     @abstractmethod
-    def dec_lock_ref(self, node):
+    def dec_lock_ref(self, node: Any):
         pass

-    @abstractmethod
     def evictable_size(self):
-
+        return 0

-    @abstractmethod
     def protected_size(self):
-
+        return 0

     def total_size(self):
         raise NotImplementedError()
sglang/srt/mem_cache/chunk_cache.py

@@ -1,7 +1,8 @@
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""
-
+
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple

 import torch

@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
-        self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.entries: Dict[str, ChunkCacheEntry] = {}
-
-        self.reset()

     def reset(self):
-
-
-    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
-        if rid not in self.entries:
-            return [], None
-
-        entry = self.entries[rid]
-        max_prefix_len = len(key)
-        return entry.value[:max_prefix_len], entry
+        pass

-    def
-
-            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-        else:
-            token_id_len = len(token_ids)
+    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
+        return [], None

+    def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)

-        if req.rid in self.entries:
-            del self.entries[req.rid]
-
     def cache_unfinished_req(self, req: Req):
-        token_id_len = len(req.fill_ids)
-
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, : len(req.fill_ids)
         ]

-
-        self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
-
-        entry = self.entries[req.rid]
-        entry.value = kv_indices
+        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
-        req.last_node = entry

     def insert(self):
         raise NotImplementedError()

-    def evict(self, num_tokens: int
+    def evict(self, num_tokens: int):
         pass

-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         return 0

-    def dec_lock_ref(self, node):
-        return 0
-
-    def evictable_size(self):
-        return 0
-
-    def pretty_print(self):
-        return ""
-
-    def protected_size(self):
+    def dec_lock_ref(self, node: Any):
         return 0

     def pretty_print(self):
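With the radix tree disabled, the rewritten `ChunkCache` above never reports a reusable prefix, and when a request finishes it frees the request's whole KV row straight from `req_to_token`. A small runnable illustration of the slice it frees, using a toy mapping table and made-up request lengths (the names mirror the diff, but these objects are stand-ins, not sglang classes):

import torch

# Toy req_to_token table: one row per request slot, one column per token position.
req_to_token = torch.arange(2 * 16).reshape(2, 16)

req_pool_idx = 1
origin_input_len, output_len = 5, 3     # hypothetical prompt/output lengths

# Same slice as cache_finished_req: the first len(input) + len(output) - 1 positions.
kv_indices = req_to_token[req_pool_idx, : origin_input_len + output_len - 1]
print(kv_indices.numel())               # 7 KV slot indices handed back to the allocator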
sglang/srt/mem_cache/hiradix_cache.py

@@ -1,5 +1,6 @@
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional

@@ -7,11 +8,12 @@ import torch

 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
-    MHATokenToKVPool,
     MHATokenToKVPoolHost,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)

@@ -21,11 +23,19 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
     ):
-        self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+            token_to_kv_pool_allocator.get_kvcache()
+        )
+        self.tp_group = tp_cache_group
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )

         # record the nodes with ongoing write through
@@ -35,7 +45,7 @@
         # todo: dynamically adjust the threshold
         self.write_through_threshold = 1
         self.load_back_threshold = 10
-        super().__init__(req_to_token_pool,
+        super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)

     def reset(self):
         TreeNode.counter = 0
@@ -53,14 +63,12 @@
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
@@ -81,14 +89,20 @@
         node.hit_count = 0

     def writing_check(self):
-
-
-
-
-
-
-
-
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchrnoize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
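The `writing_check` change above keeps tensor-parallel workers in lockstep by all-reducing the write-acknowledgement queue size with `ReduceOp.MIN`, so every rank pops the same number of acks even when their cache controllers have made different progress. A minimal local illustration of that reduction (plain Python, no process group; the per-rank queue sizes are made-up numbers):

# Illustrative only: hypothetical ack-queue sizes on TP ranks 0, 1, 2 at check time.
# With ReduceOp.MIN, every rank agrees to process only the acks that all ranks
# have already received, so their radix trees get the same updates.
per_rank_queue_sizes = [3, 5, 4]
acks_processed_everywhere = min(per_rank_queue_sizes)
print(acks_processed_everywhere)  # 3 -> each rank pops exactly 3 acks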
@@ -106,11 +120,9 @@
                 break

     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_

-    def evict(self, num_tokens: int
+    def evict(self, num_tokens: int):
         leaves = self._collect_leaves_device()
         heapq.heapify(leaves)

@@ -160,7 +172,7 @@

     def _evict_write_through_selective(self, node: TreeNode):
         # evict a node not initiated write to host
-        self.cache_controller.
+        self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
         self._delete_leaf(node)
         return num_evicted
@@ -240,10 +252,6 @@

         return device_indices

-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -270,28 +278,49 @@

         return last_node, prefix_indices

-    def
-        self
-    ):
-        node.last_access_time = time.time()
-        if len(key) == 0:
-            return
+    def read_to_load_cache(self):
+        self.load_cache_event.set()

-
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
+    def _match_prefix_helper(self, node: TreeNode, key: List):
+        node.last_access_time = time.time()
+        value = []
+        while len(key) > 0 and key[0] in node.children.keys():
             child = node.children[key[0]]
+            child.last_access_time = time.time()
             prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
-
+                node = new_node
+                break
             else:
                 self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
-
-
+                node = child
+                key = key[prefix_len:]
+        return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
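The new `_match_prefix_helper` above walks the radix tree by repeatedly descending into the child keyed by the next token and comparing keys with `_key_match`, splitting a node when only part of its key matches. A self-contained sketch of that comparison step on plain token lists (the helper name and the data here are illustrative, not sglang's API):

from typing import List

def shared_prefix_len(a: List[int], b: List[int]) -> int:
    # Length of the common prefix of two token-id sequences,
    # mirroring what a page-size-1 key match computes.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

# Example: a cached node holds key [5, 7, 9, 2]; the incoming request starts [5, 7, 3, 8].
# The match length is 2, so the node would be split after two tokens and only those
# two KV entries are reused.
print(shared_prefix_len([5, 7, 9, 2], [5, 7, 3, 8]))  # 2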
sglang/srt/mem_cache/memory_pool.py

@@ -20,9 +20,8 @@ Memory pool.

 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-TokenToKVPoolAllocator
-KVCache actually holds the physical kv cache.
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """

 import abc
@@ -92,42 +91,73 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """
+    """An allocator managing the indices to kv cache data."""

     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
         self.device = device
+        self.page_size = 1

         self.free_slots = None
         self.is_not_in_free_group = True
         self.free_group = []
         self.clear()

+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)

+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None

         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-
-        return select_index.to(self.device, non_blocking=True)
+        return select_index

     def free(self, free_index: torch.Tensor):
         if free_index.numel() == 0:
             return

         if self.is_not_in_free_group:
-            self.free_slots = torch.concat((self.free_slots, free_index
+            self.free_slots = torch.concat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)

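The hunk above splits responsibilities: `KVCache` owns the physical key/value tensors, while `TokenToKVPoolAllocator` only hands out and reclaims integer slot indices into them (and exposes the pool via `get_kvcache()`). A stripped-down, standalone sketch of that free-list pattern, modeled on the `alloc`/`free`/`clear` bodies visible in the diff (this class is illustrative, not the sglang implementation):

import torch

class ToySlotAllocator:
    """Hands out integer slot indices; the actual KV tensors live elsewhere."""

    def __init__(self, size: int, device: str = "cpu"):
        self.size = size
        self.device = device
        self.clear()

    def clear(self):
        # Slot 0 stays reserved for dummy writes from padded tokens.
        self.free_slots = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        picked = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return picked

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return
        self.free_slots = torch.concat((self.free_slots, free_index))

allocator = ToySlotAllocator(size=8)
loc = allocator.alloc(3)               # e.g. tensor([1, 2, 3])
print(loc, len(allocator.free_slots))  # 5 slots remain
allocator.free(loc)                    # indices are recycled; the KV data lives in KVCache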
@@ -142,41 +172,19 @@ class TokenToKVPoolAllocator:

     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
+        self.free_slots = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
         self.is_in_free_group = False
         self.free_group = []


-class KVCache(abc.ABC):
-
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):

     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -185,6 +193,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -201,6 +210,10 @@
         self.layer_num = layer_num
         self._create_buffers()

+        self.layer_transfer_counter = None
+        self.capture_mode = False
+        self.alt_stream = torch.cuda.Stream()
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -211,16 +224,16 @@
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
-            torch.
-                (self.size +
+            torch.zeros(
+                (self.size + self.page_size, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
                 device=self.device,
             )
             for _ in range(self.layer_num)
         ]
         self.v_buffer = [
-            torch.
-                (self.size +
+            torch.zeros(
+                (self.size + self.page_size, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
                 device=self.device,
             )
@@ -262,12 +275,28 @@
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
@@ -292,14 +321,44 @@
             cache_v.div_(v_scale)
             cache_k = cache_k.to(self.dtype)
             cache_v = cache_v.to(self.dtype)
+
         if self.store_dtype != self.dtype:
-
-
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        if self.capture_mode:
+            self.alt_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self.alt_stream):
+                self.k_buffer[layer_id][loc] = cache_k
+                self.v_buffer[layer_id][loc] = cache_v
+            torch.cuda.current_stream().wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v


+@torch.compile
+def fused_downcast(
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dtype: torch.dtype,
+    store_dtype: torch.dtype,
+    max_fp8: float,
+    min_fp8: float,
+):
+    cache_k = cache_k / k_scale
+    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
+    cache_v = cache_v / v_scale
+    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
+    cache_k = cache_k.to(dtype)
+    cache_v = cache_v.to(dtype)
+    cache_k = cache_k.view(store_dtype)
+    cache_v = cache_v.view(store_dtype)
+    return cache_k, cache_v
+
+
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
 @torch.compile(dynamic=True, backend=get_compiler_backend())
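In the `set_kv_buffer` change above, writes issued while `capture_mode` is set go through a separate CUDA stream that is fenced against the current stream on both sides. A minimal sketch of that wait/issue/wait pattern with generic tensors (guarded so it only runs where CUDA is available; the tensor names are illustrative):

import torch

if torch.cuda.is_available():
    buf = torch.zeros(16, device="cuda")
    src = torch.ones(4, device="cuda")
    loc = torch.tensor([1, 3, 5, 7], device="cuda")

    alt_stream = torch.cuda.Stream()
    # The side stream must not start before work already queued on the current stream.
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        buf[loc] = src  # the scatter-write runs on alt_stream
    # The current stream must not read `buf` until the side-stream write has finished.
    torch.cuda.current_stream().wait_stream(alt_stream)
    print(buf)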
@@ -312,6 +371,7 @@ class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
@@ -336,8 +396,8 @@
         with memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
-                torch.
-                    (size +
+                torch.zeros(
+                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                     dtype=self.store_dtype,
                     device=device,
                 )
@@ -377,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -386,6 +447,7 @@
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -400,17 +462,21 @@
         with memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
-                torch.
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]
             self.v_buffer = [
-                torch.
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]

             # [size, head_num, heavy_channel_num] for each layer
             self.label_buffer = [
-                torch.
+                torch.zeros(
                     (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                 )
                 for _ in range(layer_num)
@@ -465,7 +531,7 @@ class MHATokenToKVPoolHost:
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float =
+        host_to_device_ratio: float = 3.0,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
@@ -505,7 +571,7 @@
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )

-        self.kv_buffer = torch.
+        self.kv_buffer = torch.zeros(
             (2, self.layer_num, self.size, self.head_num, self.head_dim),
             dtype=self.dtype,
             device=self.device,
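The host-side pool above takes a `host_to_device_ratio` (now defaulting to 3.0) and lays its buffer out as `(2, layer_num, size, head_num, head_dim)` for K and V. A quick back-of-the-envelope sizing check, assuming the host pool holds `host_to_device_ratio` times the device token count, which is what the ratio parameter suggests (the dimensions below are made-up, not a real model configuration):

# Hypothetical device pool: 100k tokens, 32 layers, 8 KV heads of dim 128, fp16 (2 bytes).
device_tokens, layer_num, head_num, head_dim, dtype_bytes = 100_000, 32, 8, 128, 2
host_to_device_ratio = 3.0

host_tokens = int(device_tokens * host_to_device_ratio)
# Buffer shape (2, layer_num, host_tokens, head_num, head_dim) -> bytes:
host_bytes = 2 * layer_num * host_tokens * head_num * head_dim * dtype_bytes
print(f"{host_bytes / 1e9:.2f} GB")  # ~39.32 GB of host memory for the hierarchical cache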