sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py (new file)
@@ -0,0 +1,380 @@
+import abc
+import logging
+import threading
+from enum import IntEnum
+from functools import wraps
+
+import psutil
+import torch
+
+from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import debug_timing
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                return func(self, *args, **kwargs)
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True
+
+        return wrapper
+
+    return _decorator
+
+
+class HostKVCache(abc.ABC):
+
+    def __init__(
+        self,
+        device_pool: KVCache,
+        host_to_device_ratio: float,
+        host_size: int,
+        pin_memory: bool,
+        device: str,
+        page_size: int,
+    ):
+        self.device_pool = device_pool
+        self.dtype = device_pool.store_dtype
+        self.pin_memory = pin_memory
+        self.device = device
+        self.page_size = page_size
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = self.init_kv_buffer()
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+        self.debug = logger.isEnabledFor(logging.DEBUG)
+        self.clear()
+
+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data_by_layer(self, indices, layer_id):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def assign_flat_data(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @synchronized()
+    def clear(self):
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int64)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized()
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.available_size():
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED
+
+        return select_index
+
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized(debug_only=True)
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_backup(self, indices: torch.Tensor):
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def protect_write(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def protect_load(self, indices: torch.Tensor):
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def complete_io(self, indices: torch.Tensor):
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+
+class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (
+            (self.kv_lora_rank + self.qk_rope_head_dim)
+            * 1
+            * self.dtype.itemsize
+            * self.layer_num
+        )
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
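
The new HostKVCache base class above tracks host-side KV slots through the MemoryStateInt lifecycle: alloc() reserves IDLE slots, protect_write()/protect_load() move them into PROTECTED around a transfer, complete_io() marks the transfer as SYNCED, update_backup() marks the host copy as BACKUP, and free() returns slots to the pool. Below is a minimal standalone sketch of that state discipline only; it is not part of the package and omits the torch buffers, the device pool, and the lock/debug gating.

    from enum import IntEnum


    class MemoryStateInt(IntEnum):
        IDLE = 0
        RESERVED = 1
        PROTECTED = 2
        SYNCED = 3
        BACKUP = 4


    class SlotStateMachine:
        """Simplified model of HostKVCache's per-slot states (no real buffers)."""

        def __init__(self, size: int):
            self.state = [MemoryStateInt.IDLE] * size

        def _expect(self, idx, expected: MemoryStateInt, new: MemoryStateInt):
            for i in idx:
                if self.state[i] != expected:
                    raise ValueError(
                        f"slot {i} is {self.state[i].name}, expected {expected.name}"
                    )
                self.state[i] = new

        def alloc(self, idx):          # IDLE -> RESERVED
            self._expect(idx, MemoryStateInt.IDLE, MemoryStateInt.RESERVED)

        def protect_write(self, idx):  # RESERVED -> PROTECTED (device -> host write starts)
            self._expect(idx, MemoryStateInt.RESERVED, MemoryStateInt.PROTECTED)

        def complete_io(self, idx):    # PROTECTED -> SYNCED (transfer finished)
            self._expect(idx, MemoryStateInt.PROTECTED, MemoryStateInt.SYNCED)

        def update_backup(self, idx):  # SYNCED -> BACKUP (host copy can serve loads)
            self._expect(idx, MemoryStateInt.SYNCED, MemoryStateInt.BACKUP)

        def protect_load(self, idx):   # BACKUP -> PROTECTED (host -> device load starts)
            self._expect(idx, MemoryStateInt.BACKUP, MemoryStateInt.PROTECTED)

        def free(self, idx):           # any state -> IDLE
            for i in idx:
                self.state[i] = MemoryStateInt.IDLE


    if __name__ == "__main__":
        sm = SlotStateMachine(8)
        slots = [0, 1, 2]
        sm.alloc(slots)
        sm.protect_write(slots)   # write KV pages from the device pool
        sm.complete_io(slots)
        sm.update_backup(slots)
        sm.protect_load(slots)    # later: load the pages back to the device
        sm.complete_io(slots)
        sm.free(slots)

In the real class these checks are only enforced when debug logging is enabled, since the @synchronized(debug_only=True) methods return early otherwise.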
sglang/srt/mem_cache/radix_cache.py
@@ -461,23 +461,47 @@ class RadixCache(BasePrefixCache):
         return ret_list
 
     def _record_store_event(self, node: TreeNode):
+        # One BlockStored per ``page_size`` chunk.
         if self.enable_kv_cache_events:
-            [9 removed lines; content not shown in this rendering]
+            # First chunk links to the last page of the parent node (if any).
+            if node.parent is None:
+                parent_block_hash = None
+            else:
+                last_page_start = (
+                    (len(node.parent.key) - 1) // self.page_size
+                ) * self.page_size
+                parent_parent_tokens = node.parent.key[last_page_start:]
+                parent_block_hash = hash(tuple(parent_parent_tokens))
+
+            for start in range(0, len(node.key), self.page_size):
+                page_tokens = node.key[start : start + self.page_size]
+                if not page_tokens:
+                    continue
+
+                block_hash = hash(tuple(page_tokens))
+
+                self.kv_event_queue.append(
+                    BlockStored(
+                        block_hashes=[block_hash],
+                        parent_block_hash=parent_block_hash,
+                        token_ids=page_tokens,
+                        block_size=len(page_tokens),
+                        lora_id=None,
+                    )
                 )
-            [1 removed line; content not shown in this rendering]
+
+                # Chain next chunk to this one.
+                parent_block_hash = block_hash
 
     def _record_remove_event(self, node: TreeNode):
+        # One BlockRemoved per chunk.
         if self.enable_kv_cache_events:
-            [2 removed lines; content not shown in this rendering]
+            for start in range(0, len(node.key), self.page_size):
+                page_tokens = node.key[start : start + self.page_size]
+                if not page_tokens:
+                    continue
+                block_hash = hash(tuple(page_tokens))
+                self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
 
     def _record_all_cleared_event(self):
         if self.enable_kv_cache_events:
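
With the change above, _record_store_event emits one BlockStored event per page_size chunk of a node's key and chains the chunks by hash: the first chunk points at the hash of the parent node's last page (or None at the root), and every later chunk points at the previous chunk. The following is a small standalone sketch of just that chunking and chaining; it is not part of the package, and the token lists and page size in the example are made up.

    from typing import List, Optional


    def paged_block_events(
        node_key: List[int],
        parent_key: Optional[List[int]],
        page_size: int,
    ):
        """Yield (block_hash, parent_block_hash, page_tokens) per page, mirroring
        the chunking and chaining done by RadixCache._record_store_event."""
        if parent_key:
            # The last (possibly partial) page of the parent anchors the first chunk.
            last_page_start = ((len(parent_key) - 1) // page_size) * page_size
            parent_block_hash = hash(tuple(parent_key[last_page_start:]))
        else:
            parent_block_hash = None

        for start in range(0, len(node_key), page_size):
            page_tokens = node_key[start : start + page_size]
            if not page_tokens:
                continue
            block_hash = hash(tuple(page_tokens))
            yield block_hash, parent_block_hash, page_tokens
            # Chain the next chunk to this one.
            parent_block_hash = block_hash


    if __name__ == "__main__":
        # Example: a parent node holding tokens 0..5, a child node holding 6..15,
        # and a page size of 4.
        for h, parent, toks in paged_block_events(list(range(6, 16)), list(range(6)), 4):
            print(f"BlockStored(tokens={toks}, parent={parent}, hash={h})")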
sglang/srt/model_executor/cuda_graph_runner.py
@@ -17,12 +17,14 @@ from __future__ import annotations
 
 import bisect
 import inspect
+import logging
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 import tqdm
+from torch.profiler import ProfilerActivity, profile
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -40,11 +42,14 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
+    empty_context,
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
@@ -147,10 +152,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )
 
     gpu_mem = get_device_memory_capacity()
-    if gpu_mem is not None
-    [3 more removed lines; content not shown in this rendering]
+    if gpu_mem is not None:
+        if gpu_mem > 90 * 1024:  # H200, H20
+            capture_bs += list(range(160, 257, 8))
+        if gpu_mem > 160 * 1000:  # B200, MI300
+            capture_bs += list(range(256, 513, 16))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
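
The rewritten branch above widens the set of CUDA-graph capture batch sizes on large-memory GPUs. Below is a small standalone sketch of the same thresholds; it is not part of the package, and it assumes get_device_memory_capacity() reports the capacity in MB, which is what the 90 * 1024 and 160 * 1000 comparisons together with the H200/H20 and B200/MI300 comments suggest.

    def extra_capture_batch_sizes(gpu_mem_mb):
        """Extra CUDA-graph batch sizes for large GPUs, mirroring the hunk above."""
        extra = []
        if gpu_mem_mb is None:
            return extra
        if gpu_mem_mb > 90 * 1024:    # e.g. H200, H20 (> ~90 GB)
            extra += list(range(160, 257, 8))
        if gpu_mem_mb > 160 * 1000:   # e.g. B200, MI300 (> ~160 GB)
            extra += list(range(256, 513, 16))
        return extra


    if __name__ == "__main__":
        print(extra_capture_batch_sizes(141 * 1024))  # H200-class: adds sizes up to 256
        print(extra_capture_batch_sizes(192 * 1024))  # B200/MI300-class: adds sizes up to 512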
@@ -207,6 +213,9 @@ class CudaGraphRunner:
             model_runner.server_args.enable_two_batch_overlap
         )
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
@@ -226,6 +235,10 @@ class CudaGraphRunner:
             self.model_runner.server_args.speculative_num_draft_tokens
         )
 
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -333,50 +346,91 @@ class CudaGraphRunner:
             else True
         )
 
+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
         is_tbo_supported = (
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )
 
-        return
+        return (
+            is_bs_supported
+            and is_encoder_lens_supported
+            and is_tbo_supported
+            and capture_hidden_mode_matches
+        )
 
-    def capture(self):
-    [5 removed lines; content not shown in this rendering]
-            # Reverse the order to enable better memory sharing across cuda graphs.
-            capture_range = (
-                tqdm.tqdm(list(reversed(self.capture_bs)))
-                if get_tensor_model_parallel_rank() == 0
-                else reversed(self.capture_bs)
+    def capture(self) -> None:
+        profile_context = empty_context()
+        if self.enable_profile_cuda_graph:
+            profile_context = profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=True,
             )
-            for bs in capture_range:
-                if get_tensor_model_parallel_rank() == 0:
-                    avail_mem = get_available_gpu_memory(
-                        self.model_runner.device,
-                        self.model_runner.gpu_id,
-                        empty_cache=False,
-                    )
-                    capture_range.set_description(
-                        f"Capturing batches ({avail_mem=:.2f} GB)"
-                    )
-    [1 removed line; content not shown in this rendering]
-                with patch_model(
-                    self.model_runner.model,
-                    bs in self.compile_bs,
-                    num_tokens=bs * self.num_tokens_per_bs,
-                    tp_group=self.model_runner.tp_group,
-                ) as forward:
-                    (
-                        graph,
-                        output_buffers,
-                    ) = self.capture_one_batch_size(bs, forward)
-                    self.graphs[bs] = graph
-                    self.output_buffers[bs] = output_buffers
 
-    [2 removed lines; content not shown in this rendering]
+        with graph_capture() as graph_capture_context:
+            with profile_context as prof:
+                self.stream = graph_capture_context.stream
+                avail_mem = get_available_gpu_memory(
+                    self.model_runner.device,
+                    self.model_runner.gpu_id,
+                    empty_cache=False,
+                )
+                # Reverse the order to enable better memory sharing across cuda graphs.
+                capture_range = (
+                    tqdm.tqdm(list(reversed(self.capture_bs)))
+                    if get_tensor_model_parallel_rank() == 0
+                    else reversed(self.capture_bs)
+                )
+                for i, bs in enumerate(capture_range):
+                    if get_tensor_model_parallel_rank() == 0:
+                        avail_mem = get_available_gpu_memory(
+                            self.model_runner.device,
+                            self.model_runner.gpu_id,
+                            empty_cache=False,
+                        )
+                        capture_range.set_description(
+                            f"Capturing batches ({avail_mem=:.2f} GB)"
+                        )
+
+                    with patch_model(
+                        self.model_runner.model,
+                        bs in self.compile_bs,
+                        num_tokens=bs * self.num_tokens_per_bs,
+                        tp_group=self.model_runner.tp_group,
+                    ) as forward:
+                        (
+                            graph,
+                            output_buffers,
+                        ) = self.capture_one_batch_size(bs, forward)
+                        self.graphs[bs] = graph
+                        self.output_buffers[bs] = output_buffers
+
+                    # Save gemlite cache after each capture
+                    save_gemlite_cache()
+
+        if self.enable_profile_cuda_graph:
+            log_message = (
+                "Sorted by CUDA Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cuda_time_total", row_limit=10
+                )
+                + "\n\nSorted by CPU Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cpu_time_total", row_limit=10
+                )
+            )
+            logger.info(log_message)
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
@@ -443,7 +497,7 @@ class CudaGraphRunner:
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
@@ -509,21 +563,34 @@ class CudaGraphRunner:
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
-        [2 removed lines; content not shown in this rendering]
+
+        # If the required capture_hidden_mode changes, we need to recapture the graph
+
+        # These are the different factors that can influence the capture_hidden_mode
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        [11 removed lines; content not shown in this rendering]
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required
+        # (If we have FULL, we can emulate LAST or NULL)
+        # (If we have LAST, we can emulate NULL)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
             self.capture()
 
     def replay_prepare(