sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
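
The entries above show the OpenAI-compatible layer being split out of the old sglang/srt/openai_api package: protocol.py moves to sglang/srt/entrypoints/openai/, and the 1990-line adapter.py is removed in favor of the new serving_*.py modules. A minimal sketch of how a downstream import would migrate, assuming only the module paths listed above; the ChatCompletionRequest name is illustrative:

```python
# sglang 0.4.7: request/response models lived under sglang.srt.openai_api.
# from sglang.srt.openai_api.protocol import ChatCompletionRequest

# sglang 0.4.8: the same module now lives under sglang.srt.entrypoints.openai.
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
```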
sglang/srt/mem_cache/memory_pool_host.py (new file)

@@ -0,0 +1,380 @@
+import abc
+import logging
+import threading
+from enum import IntEnum
+from functools import wraps
+
+import psutil
+import torch
+
+from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import debug_timing
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                return func(self, *args, **kwargs)
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True
+
+        return wrapper
+
+    return _decorator
+
+
+class HostKVCache(abc.ABC):
+
+    def __init__(
+        self,
+        device_pool: KVCache,
+        host_to_device_ratio: float,
+        host_size: int,
+        pin_memory: bool,
+        device: str,
+        page_size: int,
+    ):
+        self.device_pool = device_pool
+        self.dtype = device_pool.store_dtype
+        self.pin_memory = pin_memory
+        self.device = device
+        self.page_size = page_size
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = self.init_kv_buffer()
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+        self.debug = logger.isEnabledFor(logging.DEBUG)
+        self.clear()
+
+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data_by_layer(self, indices, layer_id):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def assign_flat_data(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @synchronized()
+    def clear(self):
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int64)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized()
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.available_size():
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED
+
+        return select_index
+
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized(debug_only=True)
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_backup(self, indices: torch.Tensor):
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def protect_write(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def protect_load(self, indices: torch.Tensor):
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def complete_io(self, indices: torch.Tensor):
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+
+class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (
+            (self.kv_lora_rank + self.qk_rope_head_dim)
+            * 1
+            * self.dtype.itemsize
+            * self.layer_num
+        )
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
sglang/srt/mem_cache/radix_cache.py

@@ -23,7 +23,7 @@ import heapq
 import time
 from collections import defaultdict
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch

@@ -31,11 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
     AllBlocksCleared,
     BlockRemoved,
     BlockStored,
-    KVCacheEvent,
 )
-from sglang.srt.
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -47,9 +46,9 @@ class TreeNode:

     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
+        self.parent: TreeNode = None
+        self.key: List[int] = None
+        self.value: Optional[torch.Tensor] = None
         self.lock_ref = 0
         self.last_access_time = time.monotonic()

@@ -57,7 +56,7 @@ class TreeNode:
         # indicating the node is loading KV cache from host
         self.loading = False
         # store the host indices of KV cache
-        self.host_value = None
+        self.host_value: Optional[torch.Tensor] = None

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -99,7 +98,7 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
         enable_kv_cache_events: bool = False,
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
         self.protected_size_ = 0
         self._record_all_cleared_event()

-    def match_prefix(self, key: List[int], **kwargs) ->
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
             key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
             than the last node's value.
         """
         if self.disable or len(key) == 0:
-            return (
-                torch.empty(
+            return MatchResult(
+                device_indices=torch.empty(
                     (0,),
                     dtype=torch.int64,
                     device=self.device,
                 ),
-                self.root_node,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
             )

         if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
             value = torch.cat(value)
         else:
             value = torch.empty((0,), dtype=torch.int64, device=self.device)
-        return
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_node,
+        )

     def insert(self, key: List, value=None):
         if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
         )

         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
         self.req_to_token_pool.write(
             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
             new_indices[len(req.prefix_indices) :],
@@ -461,23 +465,47 @@ class RadixCache(BasePrefixCache):
         return ret_list

     def _record_store_event(self, node: TreeNode):
+        # One BlockStored per ``page_size`` chunk.
         if self.enable_kv_cache_events:
-
-
-
-
-
-
-
-
-
+            # First chunk links to the last page of the parent node (if any).
+            if node.parent is None:
+                parent_block_hash = None
+            else:
+                last_page_start = (
+                    (len(node.parent.key) - 1) // self.page_size
+                ) * self.page_size
+                parent_parent_tokens = node.parent.key[last_page_start:]
+                parent_block_hash = hash(tuple(parent_parent_tokens))
+
+            for start in range(0, len(node.key), self.page_size):
+                page_tokens = node.key[start : start + self.page_size]
+                if not page_tokens:
+                    continue
+
+                block_hash = hash(tuple(page_tokens))
+
+                self.kv_event_queue.append(
+                    BlockStored(
+                        block_hashes=[block_hash],
+                        parent_block_hash=parent_block_hash,
+                        token_ids=page_tokens,
+                        block_size=len(page_tokens),
+                        lora_id=None,
+                    )
                 )
-
+
+                # Chain next chunk to this one.
+                parent_block_hash = block_hash

     def _record_remove_event(self, node: TreeNode):
+        # One BlockRemoved per chunk.
         if self.enable_kv_cache_events:
-
-
+            for start in range(0, len(node.key), self.page_size):
+                page_tokens = node.key[start : start + self.page_size]
+                if not page_tokens:
+                    continue
+                block_hash = hash(tuple(page_tokens))
+                self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))

     def _record_all_cleared_event(self):
         if self.enable_kv_cache_events: