sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
The hunks below are excerpted from three of the modified files: sglang/srt/mem_cache/memory_pool.py, sglang/srt/mem_cache/radix_cache.py, and sglang/srt/model_executor/cuda_graph_runner.py.

```diff
--- a/sglang/srt/mem_cache/memory_pool.py
+++ b/sglang/srt/mem_cache/memory_pool.py
@@ -26,6 +26,7 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -33,8 +34,9 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +54,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
+
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -59,7 +62,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size, max_context_len), dtype=torch.int32, device=device
             )
@@ -119,6 +122,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )
 
+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
```
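The `region()` call on `TorchMemorySaverAdapter` now takes a tag (`GPU_MEMORY_TYPE_KV_CACHE`, defined in the new `sglang/srt/constants.py`), presumably so the memory saver can pause or release KV-cache allocations independently of other tagged regions. A minimal usage sketch of the new call pattern; the tensor shape is illustrative:

```python
import torch

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=True)

# Allocations made inside a tagged region can later be targeted as a group
# (e.g. released while the engine is idle) without touching other regions.
with adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
    req_to_token = torch.zeros((4096, 2048), dtype=torch.int32, device="cuda")
```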
```diff
@@ -153,83 +159,11 @@ class KVCache(abc.ABC):
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
-
-class TokenToKVPoolAllocator:
-    """An allocator managing the indices to kv cache data."""
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def debug_print(self) -> str:
-        return ""
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
-
     def get_cpu_copy(self, indices):
-
+        raise NotImplementedError()
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
-
+        raise NotImplementedError()
 
 
 class MHATokenToKVPool(KVCache):
```
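`TokenToKVPoolAllocator` has not disappeared: the file list above shows `paged_allocator.py` renamed to `allocator.py` (+125 -34), and the radix-cache hunks later in this diff import `BaseTokenToKVPoolAllocator` from `sglang.srt.mem_cache.allocator`, so the allocator classes appear to have moved there. A self-contained sketch of the free-group pattern visible in the removed code, which queues `free()` calls and applies them as one concatenated free:

```python
import torch


class FreeGroupSketch:
    """Minimal re-statement of the free-group pattern from the removed
    TokenToKVPoolAllocator (not the real class in mem_cache/allocator.py)."""

    def __init__(self, size: int):
        # Slot 0 is reserved for padded tokens, as in the original clear().
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int64)
        self.in_group = False
        self.pending = []

    def alloc(self, n: int):
        if n > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, idx: torch.Tensor):
        if self.in_group:
            self.pending.append(idx)  # deferred until free_group_end()
        else:
            self.free_slots = torch.cat((self.free_slots, idx))

    def free_group_begin(self):
        self.in_group, self.pending = True, []

    def free_group_end(self):
        self.in_group = False
        if self.pending:
            self.free(torch.cat(self.pending))  # one concatenated free


alloc = FreeGroupSketch(size=16)
a, b = alloc.alloc(4), alloc.alloc(3)
alloc.free_group_begin()
alloc.free(a)
alloc.free(b)
alloc.free_group_end()  # both frees applied with a single torch.cat
assert len(alloc.free_slots) == 16
```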
```diff
@@ -260,10 +194,22 @@ class MHATokenToKVPool(KVCache):
 
         self.head_num = head_num
         self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()
 
-        # used for chunked cpu-offloading
-        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -274,25 +220,30 @@ class MHATokenToKVPool(KVCache):
         )
 
     def _create_buffers(self):
-        with self.memory_saver_adapter.region():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
 
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
```
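The KV pools now route buffer allocation through an optional custom CUDA memory pool (for disaggregation over NVLink, gated by the `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` environment variable), falling back to the default caching allocator via `nullcontext()`. A minimal sketch of the same pattern, assuming a PyTorch build that exposes `torch.cuda.MemPool`/`torch.cuda.use_mem_pool`, and leaving out the Mooncake-specific allocator:

```python
from contextlib import nullcontext
from typing import List, Optional

import torch


def create_kv_buffers(
    size: int,
    head_num: int,
    head_dim: int,
    layer_num: int,
    dtype: torch.dtype,
    device: str,
    mem_pool: Optional["torch.cuda.MemPool"] = None,
) -> List[torch.Tensor]:
    # Route allocations into the custom pool when one is provided; otherwise
    # nullcontext() leaves the default CUDA caching allocator in charge.
    ctx = torch.cuda.use_mem_pool(mem_pool) if mem_pool is not None else nullcontext()
    with ctx:
        return [
            torch.zeros((size, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]
```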
```diff
@@ -349,13 +300,17 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
             kv_cache_cpu.append([])
-            for i in range(0, len(indices),
-                chunk_indices = indices[i : i +
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                     "cpu", non_blocking=True
                 )
@@ -368,12 +323,13 @@ class MHATokenToKVPool(KVCache):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
-            for i in range(0, len(indices),
-                chunk_indices = indices[i : i +
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu, v_cpu = (
-                    kv_cache_cpu[layer_id][i //
-                    kv_cache_cpu[layer_id][i //
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
                 )
                 assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                 k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
```
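`get_cpu_copy`/`load_cpu_copy` now read the chunk size from the `cpu_offloading_chunk_size` attribute that moved up into `KVCache`. The offload walks the requested indices in fixed-size chunks, issues non-blocking device-to-host copies, and synchronizes once at the end. A standalone sketch of the same chunked-offload loop for a single buffer; names are illustrative:

```python
from typing import List

import torch

CHUNK_SIZE = 8192  # matches cpu_offloading_chunk_size in the diff


def offload_rows_to_cpu(gpu_buffer: torch.Tensor, indices: torch.Tensor) -> List[torch.Tensor]:
    """Copy gpu_buffer[indices] to CPU in CHUNK_SIZE slices of the index list."""
    chunks = []
    for i in range(0, len(indices), CHUNK_SIZE):
        chunk_indices = indices[i : i + CHUNK_SIZE]
        # non_blocking=True allows the copies to overlap where possible;
        # correctness is restored by the single synchronize below.
        chunks.append(gpu_buffer[chunk_indices].to("cpu", non_blocking=True))
    torch.cuda.synchronize()
    return chunks
```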
```diff
@@ -569,16 +525,34 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
 
-        with
-
-
-
-
-
-
-
-
-
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
         self.layer_transfer_counter = None
 
@@ -604,6 +578,9 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -677,6 +654,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
```
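The MLA pool gains the same chunked offload/reload pair, but over a single `kv_buffer` per layer instead of separate K and V buffers. The reload side, sketched standalone as the counterpart of the offload sketch above:

```python
from typing import List

import torch

CHUNK_SIZE = 8192  # cpu_offloading_chunk_size


def load_rows_from_cpu(
    gpu_buffer: torch.Tensor, cpu_chunks: List[torch.Tensor], indices: torch.Tensor
) -> None:
    """Write chunked CPU copies back into gpu_buffer at the same indices."""
    for i in range(0, len(indices), CHUNK_SIZE):
        chunk_indices = indices[i : i + CHUNK_SIZE]
        chunk = cpu_chunks[i // CHUNK_SIZE]
        assert chunk.shape[0] == len(chunk_indices)
        gpu_buffer[chunk_indices] = chunk.to(gpu_buffer.device, non_blocking=True)
    torch.cuda.synchronize()
```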
```diff
@@ -704,7 +708,7 @@ class DoubleSparseTokenToKVPool(KVCache):
             end_layer,
         )
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
```
```diff
--- a/sglang/srt/mem_cache/radix_cache.py
+++ b/sglang/srt/mem_cache/radix_cache.py
@@ -23,7 +23,7 @@ import heapq
 import time
 from collections import defaultdict
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
@@ -31,11 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
     AllBlocksCleared,
     BlockRemoved,
     BlockStored,
-    KVCacheEvent,
 )
-from sglang.srt.
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -47,9 +46,9 @@ class TreeNode:
 
     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
+        self.parent: TreeNode = None
+        self.key: List[int] = None
+        self.value: Optional[torch.Tensor] = None
         self.lock_ref = 0
         self.last_access_time = time.monotonic()
 
@@ -57,7 +56,7 @@ class TreeNode:
         # indicating the node is loading KV cache from host
         self.loading = False
         # store the host indices of KV cache
-        self.host_value = None
+        self.host_value: Optional[torch.Tensor] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -99,7 +98,7 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
         enable_kv_cache_events: bool = False,
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
         self.protected_size_ = 0
         self._record_all_cleared_event()
 
-    def match_prefix(self, key: List[int], **kwargs) ->
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
             key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
             than the last node's value.
         """
         if self.disable or len(key) == 0:
-            return (
-                torch.empty(
+            return MatchResult(
+                device_indices=torch.empty(
                     (0,),
                     dtype=torch.int64,
                     device=self.device,
                 ),
-                self.root_node,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
             )
 
         if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
             value = torch.cat(value)
         else:
             value = torch.empty((0,), dtype=torch.int64, device=self.device)
-        return
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_node,
+        )
 
     def insert(self, key: List, value=None):
         if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
         )
 
         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
         self.req_to_token_pool.write(
             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
             new_indices[len(req.prefix_indices) :],
```
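`match_prefix` now returns a `MatchResult` (exported from `sglang.srt.mem_cache.base_prefix_cache`, +52 -8 in the file list) instead of a bare `(indices, last_node)` tuple, and the caller above still tuple-unpacks it, taking four values. The real field definitions live in `base_prefix_cache.py`; the sketch below is a hypothetical stand-in that only matches the keyword arguments and the four-way unpacking visible in this diff, with the fourth field (`host_hit_length`) purely illustrative:

```python
from typing import NamedTuple

import torch


class MatchResultSketch(NamedTuple):
    """Stand-in for sglang's MatchResult; anything beyond the three keyword
    arguments used in the diff is an assumption, not the real definition."""

    device_indices: torch.Tensor  # matched KV indices already on device
    last_device_node: object      # deepest matched node on device
    last_host_node: object        # deepest matched node including host cache
    host_hit_length: int = 0      # illustrative fourth field with a default


# A NamedTuple keeps old tuple-style call sites working:
indices, last_node, _, _ = MatchResultSketch(
    device_indices=torch.empty((0,), dtype=torch.int64),
    last_device_node=None,
    last_host_node=None,
)
```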
```diff
--- a/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/sglang/srt/model_executor/cuda_graph_runner.py
@@ -46,6 +46,10 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
 )
 
 logger = logging.getLogger(__name__)
@@ -207,8 +211,10 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.
-        self.
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
@@ -242,13 +248,13 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-
-        self.
-
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
@@ -299,18 +305,30 @@
         else:
             self.encoder_lens = None
 
-        if self.
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.
+                    self.max_num_token,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.
-
-
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )
 
         # Capture
         try:
@@ -322,20 +340,23 @@
         )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.
-
-
-
-
-                if self.disable_padding
-                else total_global_tokens <= self.max_bs
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
         else:
-
-
-
-
-
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
```
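`can_run` now collapses the DP-attention and single-rank cases into one `cuda_graph_bs` value and checks it against the captured graphs: with padding disabled only exactly-captured batch sizes can be replayed, otherwise anything up to the largest captured size can be padded up, and MLP-sync runs additionally require `can_run_dp_cuda_graph`. A condensed standalone sketch of that check, with illustrative names:

```python
from typing import Dict, List


def can_replay_cuda_graph(
    cuda_graph_bs: int,
    graphs: Dict[int, object],
    max_bs: int,
    disable_padding: bool,
    require_mlp_sync: bool = False,
    can_run_dp_cuda_graph: bool = True,
) -> bool:
    # Exact match required when padding is disabled; otherwise pad up to the
    # largest captured batch size.
    ok = cuda_graph_bs in graphs if disable_padding else cuda_graph_bs <= max_bs
    if require_mlp_sync:
        ok = ok and can_run_dp_cuda_graph
    return ok
```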
```diff
@@ -456,11 +477,11 @@
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )
 
-        if self.
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
-                        num_tokens // self.dp_size + (i <
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                         for i in range(self.dp_size)
                     ],
                     dtype=torch.int32,
```
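During capture, the padded `num_tokens` is split across DP ranks with the remainder going to the first ranks. A worked example of the expression added above:

```python
# num_tokens // dp_size + (i < num_tokens % dp_size), evaluated per rank:
num_tokens, dp_size = 10, 4
split = [num_tokens // dp_size + (i < num_tokens % dp_size) for i in range(dp_size)]
assert split == [3, 3, 2, 2] and sum(split) == num_tokens
```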
```diff
@@ -469,6 +490,16 @@
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -604,15 +635,18 @@
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.
-
-
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs
```
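Batch-size padding still rounds up via `bisect.bisect_left` over the sorted `capture_bs` list, but the DP path now derives `total_batch_size` from the global token counts (divided by `num_tokens_per_bs` for EAGLE) before bucketing. A small example of the bucketing step, with an illustrative capture list:

```python
import bisect

capture_bs = [1, 2, 4, 8, 16, 32]  # sorted list of captured batch sizes


def pad_to_captured_bs(raw_bs: int) -> int:
    # Round up to the nearest captured batch size; raw_bs must not exceed the max.
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]


assert pad_to_captured_bs(3) == 4
assert pad_to_captured_bs(8) == 8
```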
```diff
@@ -624,7 +658,7 @@
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if pp_proxy_tensors:
@@ -636,27 +670,28 @@
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
-
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
-            self.encoder_lens,
-
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
+            self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
 
         # Store fields
@@ -704,11 +739,7 @@
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                 positions=None,
                 retrive_index=None,
                 retrive_next_token=None,
```