sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):
 
     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
+            req.req_pool_idx,
+            # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
+            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
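Why the max(..., 0) guard: on a disaggregated decode server a request can finish before emitting any output token, and the old expression would then skip one slot of the prompt's KV cache. A minimal arithmetic sketch (hypothetical token lists, not a real Req):

    origin_input_ids = [1, 2, 3, 4]
    for output_ids in ([], [7, 8]):
        old = len(origin_input_ids) + len(output_ids) - 1          # 3 slots when output is empty
        new = len(origin_input_ids) + max(len(output_ids) - 1, 0)  # 4 slots: frees the whole prompt
        print(output_ids, old, new)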
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -335,13 +335,13 @@ class HiRadixCache(RadixCache):
         return value, last_node
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
         value = []
 
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
-            child.last_access_time = time.time()
+            child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
@@ -386,7 +386,7 @@ class HiRadixCache(RadixCache):
         return new_node
 
     def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
         if len(key) == 0:
             return 0
 
@@ -395,7 +395,7 @@ class HiRadixCache(RadixCache):
 
         while len(key) > 0 and child_key in node.children.keys():
             node = node.children[child_key]
-            node.last_access_time = time.time()
+            node.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(node.key, key)
 
             if prefix_len == len(node.key):
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -38,11 +38,17 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_compiler_backend, next_power_of_2
+from sglang.srt.utils import (
+    debug_timing,
+    get_compiler_backend,
+    is_cuda,
+    next_power_of_2,
+)
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
+_is_cuda = is_cuda()
 
 
 class ReqToTokenPool:
@@ -94,6 +100,33 @@ class ReqToTokenPool:
 
 
 class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
+        self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
 
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -217,30 +250,24 @@ class MHATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
         self.head_num = head_num
         self.head_dim = head_dim
-        self.layer_num = layer_num
         self._create_buffers()
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
 
         self.layer_transfer_counter = None
-        self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream()
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
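The inherited store_dtype handles fp8 KV caches by bit-reinterpreting them as torch.uint8 for storage ops that fp8 does not yet support. A standalone sketch of the round-trip (assumes a PyTorch build with fp8 dtypes; buffer names are illustrative):

    import torch

    cache = torch.zeros(16, dtype=torch.uint8)     # storage buffer in store_dtype
    new_k = torch.randn(4).to(torch.float8_e5m2)   # incoming fp8 keys

    cache[0:4] = new_k.view(torch.uint8)           # write through the uint8 view
    restored = cache[0:4].view(torch.float8_e5m2)  # zero-copy reinterpret back
    assert torch.equal(restored.view(torch.uint8), new_k.view(torch.uint8))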
@@ -357,6 +384,8 @@ class MHATokenToKVPool(KVCache):
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
     ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
@@ -370,7 +399,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and cache_k.shape[0] < 4:
+        if get_is_capture_mode() and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
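The capture-mode branch is plain two-stream CUDA overlap: the V copy is issued on alt_stream so it can run concurrently with the K copy. A minimal sketch of the pattern with hypothetical buffers (requires a CUDA device):

    import torch

    main, alt = torch.cuda.current_stream(), torch.cuda.Stream()
    k_src = torch.randn(1024, device="cuda"); k_dst = torch.empty_like(k_src)
    v_src = torch.randn(1024, device="cuda"); v_dst = torch.empty_like(v_src)

    alt.wait_stream(main)            # alt must observe all prior work on main
    k_dst.copy_(k_src)               # K copy runs on the main stream
    with torch.cuda.stream(alt):
        v_dst.copy_(v_src)           # V copy overlaps on the alternate stream
    main.wait_stream(alt)            # rejoin before anything reads v_dst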
@@ -493,26 +522,21 @@ class MLATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
+
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
-        self.layer_num = layer_num
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
 
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
                 torch.zeros(
@@ -524,7 +548,6 @@ class MLATokenToKVPool(KVCache):
         ]
 
         self.layer_transfer_counter = None
-        self.page_size = page_size
 
         kv_size = self.get_kv_size_bytes()
         logger.info(
@@ -637,20 +660,18 @@ class DoubleSparseTokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
@@ -673,9 +694,6 @@ class DoubleSparseTokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
 
@@ -743,7 +761,7 @@ class HostKVCache(abc.ABC):
 
     def __init__(
         self,
-        device_pool: MHATokenToKVPool,
+        device_pool: KVCache,
        host_to_device_ratio: float,
         host_size: int,
         pin_memory: bool,
@@ -915,6 +933,8 @@ class HostKVCache(abc.ABC):
 
 
 class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
@@ -998,6 +1018,8 @@ class MHATokenToKVPoolHost(HostKVCache):
 
 
 class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
     def __init__(
         self,
        device_pool: MLATokenToKVPool,
sglang/srt/mem_cache/multimodal_cache.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Dict
+
+import torch
+
+
+class MultiModalCache:
+    """MultiModalCache is used to store vlm encoder results"""
+
+    def __init__(
+        self,
+        max_size: int,
+    ):
+        self.max_size = max_size
+        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.current_size = 0
+
+    def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
+        if mm_hash in self.mm_cache:
+            return True
+        data_size = self._get_tensor_size(embedding)
+        if self.current_size + data_size > self.max_size:
+            return False
+        self.mm_cache[mm_hash] = embedding
+        self.current_size += data_size
+        return True
+
+    def get(self, mm_hash: int) -> torch.Tensor:
+        return self.mm_cache.get(mm_hash)
+
+    def free(self, mm_hash: int) -> bool:
+        if mm_hash not in self.mm_cache:
+            return False
+        old_embedding = self.mm_cache.pop(mm_hash)
+        self.current_size -= self._get_tensor_size(old_embedding)
+        return True
+
+    def clear(self):
+        self.mm_cache.clear()
+        self.current_size = 0
+
+    def _get_tensor_size(self, embedding: torch.Tensor):
+        return embedding.element_size() * embedding.numel()
+
+    def __len__(self):
+        return len(self.mm_cache)
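A quick usage sketch of the new cache (sizes and key are illustrative; any int key works, e.g. a hash of the raw image):

    import torch

    cache = MultiModalCache(max_size=1 << 20)  # 1 MiB budget, counted in tensor bytes
    key = hash(b"image-bytes")
    emb = torch.randn(16, 64)                  # stand-in vlm encoder output (4 KiB)

    assert cache.put(key, emb)                 # fits under the budget
    assert cache.get(key) is emb               # hit returns the stored tensor
    assert cache.free(key)                     # subtracts its accounted bytes
    assert len(cache) == 0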
sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.disaggregation.kv_events import (
+    AllBlocksCleared,
+    BlockRemoved,
+    BlockStored,
+    KVCacheEvent,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -45,7 +51,7 @@ class TreeNode:
         self.key = None
         self.value = None
         self.lock_ref = 0
-        self.last_access_time = time.time()
+        self.last_access_time = time.monotonic()
 
         self.hit_count = 0
         # indicating the node is loading KV cache from host
@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
+        enable_kv_cache_events: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
         self.disable = disable
+        self.enable_kv_cache_events = enable_kv_cache_events
+        self.kv_event_queue = []
 
         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
@@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
         self.protected_size_ = 0
+        self._record_all_cleared_event()
 
     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache):
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
+            self._record_remove_event(x)
+
     def inc_lock_ref(self, node: TreeNode):
         if self.disable:
             return 0
@@ -322,14 +334,14 @@ class RadixCache(BasePrefixCache):
     ##### Internal Helper Functions #####
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
 
         child_key = self.get_child_key_fn(key)
 
         value = []
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
-            child.last_access_time = time.time()
+            child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
@@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache):
 
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
+        self._record_remove_event(child)
         new_node = TreeNode()
         new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
@@ -358,10 +371,14 @@ class RadixCache(BasePrefixCache):
         child.key = child.key[split_len:]
         child.value = child.value[split_len:]
         new_node.parent.children[self.get_child_key_fn(key)] = new_node
+
+        self._record_store_event(new_node)
+        self._record_store_event(child)
+
         return new_node
 
     def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
         if len(key) == 0:
             return 0
 
@@ -370,7 +387,7 @@ class RadixCache(BasePrefixCache):
         total_prefix_length = 0
         while len(key) > 0 and child_key in node.children.keys():
             node = node.children[child_key]
-            node.last_access_time = time.time()
+            node.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(node.key, key)
             total_prefix_length += prefix_len
             key = key[prefix_len:]
@@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache):
             new_node.value = value
             node.children[child_key] = new_node
             self.evictable_size_ += len(value)
+            self._record_store_event(new_node)
         return total_prefix_length
 
     def _print_helper(self, node: TreeNode, indent: int):
@@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache):
 
         return ret_list
 
+    def _record_store_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            parent_block_hash = hash(tuple(node.parent.key))
+            self.kv_event_queue.append(
+                BlockStored(
+                    block_hashes=[block_hash],
+                    parent_block_hash=parent_block_hash,
+                    token_ids=node.key,
+                    block_size=len(node.key),
+                    lora_id=None,
+                )
+            )
+
+    def _record_remove_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
+
+    def _record_all_cleared_event(self):
+        if self.enable_kv_cache_events:
+            self.kv_event_queue.append(AllBlocksCleared())
+
+    def take_events(self):
+        """Atomically takes all events and clears the queue.
+
+        Returns:
+            A list of KV cache events.
+        """
+        if not self.enable_kv_cache_events:
+            return []
+        events = self.kv_event_queue
+        self.kv_event_queue = []
+        return events
+
 
 if __name__ == "__main__":
     tree = RadixCache(None, None, page_size=1, disable=False)
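The queue gives consumers a pull-based stream of cache mutations. A minimal sketch of draining it (same constructor as the __main__ block above, plus the new flag; insert is the existing public method):

    tree = RadixCache(None, None, page_size=1, disable=False,
                      enable_kv_cache_events=True)
    tree.insert([1, 2, 3])           # records a BlockStored event
    events = tree.take_events()      # e.g. [AllBlocksCleared(), BlockStored(...)]
    assert tree.take_events() == []  # the queue was atomically cleared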
sglang/srt/metrics/collector.py
CHANGED
@@ -154,7 +154,7 @@ class SchedulerMetricsCollector:
         from prometheus_client import Counter, Gauge
 
         self.labels = labels
-        self.last_log_time = time.time()
+        self.last_log_time = time.perf_counter()
 
         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
@@ -294,7 +294,7 @@ class SchedulerMetricsCollector:
             self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
         )
 
-        self.last_log_time = time.time()
+        self.last_log_time = time.perf_counter()
 
 
 class TokenizerMetricsCollector:
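last_log_time is only ever compared against a later reading, so time.perf_counter() (monotonic, high resolution) is the appropriate clock; wall-clock time.time() can jump under NTP adjustments. The usage pattern in brief:

    import time

    last_log_time = time.perf_counter()
    # ... scheduler iteration ...
    elapsed = time.perf_counter() - last_log_time  # immune to wall-clock adjustments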
sglang/srt/mm_utils.py
CHANGED
@@ -36,6 +36,16 @@ from io import BytesIO
 import numpy as np
 from PIL import Image
 
+from sglang.srt.utils import flatten_nested_list
+
+
+def has_valid_data(data) -> bool:
+    if data is None:
+        return False
+    if isinstance(data, list):
+        return any(has_valid_data(item) for item in flatten_nested_list(data))
+    return True
+
 
 def select_best_resolution(original_size, possible_resolutions):
     """
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -30,6 +30,7 @@ from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
@@ -46,6 +47,13 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -210,7 +218,10 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        if global_server_args_dict["attention_backend"] == "flashmla":
+            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        else:
+            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
@@ -236,6 +247,7 @@ class CudaGraphRunner:
         self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+        self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
 
         # pipeline parallelism
         if self.pp_size > 1:
@@ -306,17 +318,12 @@ class CudaGraphRunner:
 
     @contextmanager
     def model_capture_mode(self):
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = True
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = True
+        global is_capture_mode
+        is_capture_mode = True
 
         yield
 
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = False
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = False
+        is_capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
@@ -399,6 +406,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        self.num_token_non_padded[...] = num_tokens
 
         # pipeline parallelism
         if self.pp_size > 1:
@@ -457,6 +465,7 @@ class CudaGraphRunner:
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
+            num_token_non_padded=self.num_token_non_padded,
         )
 
         if lora_paths is not None:
@@ -552,6 +561,7 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
+        self.num_token_non_padded[...] = len(forward_batch.input_ids)
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
@@ -604,6 +614,7 @@ class CudaGraphRunner:
 
         # Replay
         self.graphs[self.bs].replay()
+
         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
             return LogitsProcessorOutput(