sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,158 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from sglang.srt.mem_cache.memory_pool import KVCache
|
11
|
+
|
12
|
+
|
13
|
+
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Write the KV-cache slot indices for a batch of extend requests into ``out_indices``.

    Requests are processed one at a time with plain PyTorch ops. For request
    ``i``, the new tokens ``prefix_lens[i] .. seq_lens[i] - 1`` are placed in
    up to three parts:

    1. the remaining slots of the prefix's last, partially filled page,
       continuing from ``last_loc[i] + 1``;
    2. whole new pages taken from ``free_pages``;
    3. a trailing, partially filled new page (the last new page taken).

    Parts 2 and 3 only exist when the extend spills past the prefix's page.
    New pages are consumed from the head of ``free_pages`` in request order.

    Args:
        prefix_lens: 1-D tensor of per-request lengths already cached.
        seq_lens: 1-D tensor of per-request lengths after the extend.
        last_loc: 1-D tensor; ``last_loc[i] + 1`` is the slot of request i's
            first new token when it lands in the prefix's page.
        free_pages: 1-D tensor of free page ids; only its head is read.
        out_indices: 1-D tensor of length ``sum(seq_lens - prefix_lens)``,
            filled in place.
        page_size: number of token slots per page.
        device: device on which the in-page offsets are created.
    """
    # All per-request bookkeeping is integer arithmetic; pull the lengths to the
    # host once instead of indexing (and syncing on) device tensors per element.
    prefix_list = prefix_lens.tolist()
    seq_list = seq_lens.tolist()
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)

    out_start = 0  # first output position of the current request
    page_start = 0  # first free_pages entry owned by the current request
    for i, (prefix_len, seq_len) in enumerate(zip(prefix_list, seq_list)):
        extend_len = seq_len - prefix_len
        out_end = out_start + extend_len
        # Pages covered by the prefix (a partially filled last page counts).
        prefix_pages = (prefix_len + page_size - 1) // page_size
        prefix_boundary = prefix_pages * page_size
        num_new_pages = (seq_len + page_size - 1) // page_size - prefix_pages

        # Part 1: fill the prefix's partially filled last page.
        num1 = min(seq_len, prefix_boundary) - prefix_len
        if num1 > 0:
            out_indices[out_start : out_start + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1]
            )

        # Parts 2 and 3 only apply when the extend spills past that page.
        # Without this check, a request that stays inside its prefix page gets
        # a negative part-2 length, and part 3 would be written from another
        # request's page, overwriting already-filled output positions.
        if num1 < extend_len:
            # Part 2: whole new pages.
            num_full_pages = seq_len // page_size - prefix_pages
            if num_full_pages > 0:
                pages = free_pages[page_start : page_start + num_full_pages] * page_size
                part2_start = out_start + num1
                part2_end = part2_start + num_full_pages * page_size
                out_indices[part2_start:part2_end] = (
                    pages.view(-1, 1) + pos_in_page.view(1, -1)
                ).view(-1)

            # Part 3: trailing partially filled new page (the last new page).
            num3 = seq_len % page_size
            if num3 > 0:
                last_page = free_pages[page_start + num_new_pages - 1]
                out_indices[out_end - num3 : out_end] = (
                    last_page * page_size + pos_in_page[:num3]
                )

        out_start = out_end
        page_start += num_new_pages
|
65
|
+
|
66
|
+
|
67
|
+
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged KV-cache allocator whose extend/decode allocation uses plain PyTorch ops.

    Overrides ``alloc_extend`` and ``alloc_decode`` of
    ``PagedTokenToKVPoolAllocator``. Both return ``torch.int32`` slot indices,
    or ``None`` when there are not enough free pages.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        # NOTE(review): the trailing positional ``1`` is forwarded unchanged to
        # the base constructor; its meaning is defined there.
        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for the new tokens of a batch of extend requests.

        Args:
            prefix_lens: per-request lengths already cached.
            seq_lens: per-request lengths after the extend.
            last_loc: per-request slot preceding the first new token.
            extend_num_tokens: total number of new tokens in the batch.

        Returns:
            An int32 tensor of ``extend_num_tokens`` slot indices, or ``None``
            if there are not enough free pages.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # The kernel consumed the first ``num_new_pages`` free pages.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a decode step.

        Args:
            seq_lens: per-request lengths including the token being decoded.
            last_loc: per-request slot of the previous token.

        Returns:
            An int32 tensor with one slot index per request, or ``None`` if
            there are not enough free pages.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        # A request needs a fresh page when its new token is the first of a
        # page, i.e. when (seq_len - 1) is a multiple of page_size. For
        # page_size > 1 this is exactly ``seq_len % page_size == 1``; written
        # this way it also holds for page_size == 1.
        need_new_pages = ((seq_lens - 1) % self.page_size == 0).int()
        num_new_pages = need_new_pages.sum().item()

        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        if num_new_pages == 0:
            out_indices = last_loc + 1
        else:
            end_new_pages = torch.cumsum(need_new_pages, 0)
            start_new_pages = end_new_pages - need_new_pages
            # Requests that do not need a page can carry an index equal to
            # num_new_pages (one past the pages taken), which may be past the
            # end of free_pages. Clamp so the gather stays in bounds; those
            # entries are multiplied by zero below.
            page_idx = torch.clamp(start_new_pages, max=num_new_pages - 1)
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                page_idx
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()
|
@@ -71,8 +71,10 @@ class HiRadixCache(RadixCache):
|
|
71
71
|
self.tp_group = tp_cache_group
|
72
72
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
73
73
|
self.enable_storage = hicache_storage_backend is not None
|
74
|
-
# todo: customizable storage prefetch threshold
|
74
|
+
# todo: customizable storage prefetch threshold and timeout
|
75
75
|
self.prefetch_threshold = 256
|
76
|
+
self.prefetch_timeout = 3 # seconds
|
77
|
+
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
76
78
|
|
77
79
|
self.load_cache_event = threading.Event()
|
78
80
|
self.cache_controller = HiCacheController(
|
@@ -87,13 +89,6 @@ class HiRadixCache(RadixCache):
|
|
87
89
|
prefetch_threshold=self.prefetch_threshold,
|
88
90
|
)
|
89
91
|
|
90
|
-
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
91
|
-
# todo: customizable storage prefetch timeout
|
92
|
-
self.prefetch_timeout = 3 # seconds
|
93
|
-
logger.info(
|
94
|
-
f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
|
95
|
-
)
|
96
|
-
|
97
92
|
# record the nodes with ongoing write through
|
98
93
|
self.ongoing_write_through = {}
|
99
94
|
# record the node segments with ongoing load back
|
@@ -151,7 +146,7 @@ class HiRadixCache(RadixCache):
|
|
151
146
|
|
152
147
|
def write_backup_storage(self, node: TreeNode):
|
153
148
|
operation_id = self.cache_controller.write_storage(
|
154
|
-
node.host_value, node.key, node.
|
149
|
+
node.host_value, node.key, node.hash_value
|
155
150
|
)
|
156
151
|
self.ongoing_backup[operation_id] = node
|
157
152
|
node.protect_host()
|
@@ -414,18 +409,18 @@ class HiRadixCache(RadixCache):
|
|
414
409
|
group=self.tp_group,
|
415
410
|
)
|
416
411
|
for _ in range(queue_size.item()):
|
417
|
-
ack_id,
|
418
|
-
self.cache_controller.ack_backup_queue.get()
|
419
|
-
)
|
412
|
+
ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
|
420
413
|
host_node = self.ongoing_backup[ack_id]
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
414
|
+
|
415
|
+
if completed_tokens > 0:
|
416
|
+
if completed_tokens < len(host_node.key):
|
417
|
+
# backup is only partially successful, split the node
|
418
|
+
new_node = self._split_node(
|
419
|
+
host_node.key, host_node, completed_tokens
|
420
|
+
)
|
421
|
+
new_node.backuped_storage = True
|
422
|
+
else:
|
423
|
+
host_node.backuped_storage = True
|
429
424
|
host_node.release_host()
|
430
425
|
del self.ongoing_backup[ack_id]
|
431
426
|
|
@@ -471,6 +466,10 @@ class HiRadixCache(RadixCache):
|
|
471
466
|
req_id
|
472
467
|
]
|
473
468
|
|
469
|
+
if operation.host_indices is None:
|
470
|
+
# prefetch has not been issued due to insufficient host memory
|
471
|
+
return True
|
472
|
+
|
474
473
|
if not self.can_terminate_prefetch(operation):
|
475
474
|
return False
|
476
475
|
|
@@ -565,10 +564,6 @@ class HiRadixCache(RadixCache):
|
|
565
564
|
if host_indices is None:
|
566
565
|
self.evict_host(prefetch_length)
|
567
566
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
568
|
-
if host_indices is None:
|
569
|
-
last_host_node.release_host()
|
570
|
-
# no sufficient host memory to prefetch
|
571
|
-
return
|
572
567
|
operation = self.cache_controller.prefetch(
|
573
568
|
req_id, host_indices, new_input_tokens, last_hash
|
574
569
|
)
|
@@ -717,6 +712,21 @@ class HiRadixCache(RadixCache):
|
|
717
712
|
node.children[child_key] = new_node
|
718
713
|
self.evictable_size_ += len(value)
|
719
714
|
|
715
|
+
if self.enable_storage:
|
716
|
+
last_hash = node.get_last_hash_value()
|
717
|
+
assert (node == self.root_node) or (
|
718
|
+
last_hash is not None
|
719
|
+
), "Parent node must have a hash value with storage enabled"
|
720
|
+
new_node.hash_value = []
|
721
|
+
for idx in range(0, len(key), self.page_size):
|
722
|
+
new_node.hash_value.append(
|
723
|
+
self.cache_controller.get_hash_str(
|
724
|
+
key[idx : idx + self.page_size],
|
725
|
+
prior_hash=last_hash,
|
726
|
+
)
|
727
|
+
)
|
728
|
+
last_hash = new_node.hash_value[-1]
|
729
|
+
|
720
730
|
if self.cache_controller.write_policy != "write_back":
|
721
731
|
self.inc_hit_count(new_node)
|
722
732
|
return total_prefix_length
|
@@ -0,0 +1,421 @@
|
|
1
|
+
"""Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
|
2
|
+
|
3
|
+
import heapq
|
4
|
+
import time
|
5
|
+
from collections import defaultdict
|
6
|
+
from typing import TYPE_CHECKING, Any, List, Optional
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
11
|
+
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
12
|
+
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from sglang.srt.managers.schedule_batch import Req
|
16
|
+
else:
|
17
|
+
Req = Any # Placeholder for Req type when not type checking
|
18
|
+
|
19
|
+
|
20
|
+
class LoRAKey:
    """Radix-tree key for the LoRA cache: an adapter id paired with token ids.

    The token ids form the key's length; the adapter id is carried alongside.
    """

    def __init__(self, lora_id: str, token_ids: List[int]):
        # Adapter identifier; per the original note it should be a hash of the
        # adapter path. NOTE(review): ``None`` appears to mean "no adapter".
        self.lora_id = lora_id
        # Token ids making up this key.
        self.token_ids = token_ids

    def __len__(self):
        """Number of tokens in the key (the adapter id does not count)."""
        token_ids = self.token_ids
        return len(token_ids)
|
30
|
+
|
31
|
+
|
32
|
+
def get_child_key(key: "LoRAKey"):
    """Return the ``children`` dict key for a child whose key starts like ``key``.

    A child matches only when both the adapter id and the first token agree.
    The ``(lora_id, first_token)`` tuple is hashed so the two components stay
    separate. Concatenating them as strings (the previous scheme) made, e.g.,
    lora_id "a" + token 12 collide with lora_id "a1" + token 2, after which
    ``_key_match`` raises on the mismatching adapter ids. ``None`` (no adapter)
    is likewise kept distinct from any string id.

    Args:
        key: a non-empty key; only ``lora_id`` and ``token_ids[0]`` are read.

    Returns:
        An int hash used as a dict key.
    """
    return hash((key.lora_id, key.token_ids[0]))
|
39
|
+
|
40
|
+
|
41
|
+
class LoRATreeNode:
    """A node of the LoRA radix tree.

    A node holds the key segment on the edge from its parent (``key``) and the
    values stored for it (``value``). Nodes compare by ``last_access_time``, so
    a heap of nodes pops the least recently accessed one first.
    """

    # Id source shared by all nodes; advanced on every construction.
    counter = 0

    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(LoRATreeNode)
        self.parent: LoRATreeNode = None
        self.key: LoRAKey = None
        self.value: Optional[torch.Tensor] = None
        self.lock_ref = 0
        self.last_access_time = time.monotonic()

        # Use an explicit id when given, otherwise the next counter value;
        # the shared counter is bumped in both cases.
        if id is not None:
            self.id = id
        else:
            self.id = LoRATreeNode.counter
        LoRATreeNode.counter += 1

    @property
    def evicted(self):
        """True when the node currently holds no value."""
        return self.value is None

    def __lt__(self, other: "LoRATreeNode"):
        # Earlier access sorts first (used by heapq during eviction).
        mine, theirs = self.last_access_time, other.last_access_time
        return mine < theirs
|
62
|
+
|
63
|
+
|
64
|
+
def _key_match(key0: LoRAKey, key1: LoRAKey):
    """Return the length of the common leading run of token ids of two keys.

    Raises:
        ValueError: if the two keys carry different ``lora_id`` values; keys of
            different adapters must never be compared.
    """
    if key0.lora_id != key1.lora_id:
        raise ValueError(
            f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
        )
    matched = 0
    for left, right in zip(key0.token_ids, key1.token_ids):
        if left != right:
            return matched
        matched += 1
    return matched
|
75
|
+
|
76
|
+
|
77
|
+
class LoRARadixCache(BasePrefixCache):
    """Radix prefix cache keyed by (LoRA adapter id, token ids).

    Derived from ``RadixCache``, but nodes store ``LoRAKey`` keys, so cached
    prefixes are only shared between requests that use the same adapter.
    Only ``page_size == 1`` is supported.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
    ):
        if page_size > 1:
            raise ValueError("LoRARadixCache currently only supports page_size = 1")

        if token_to_kv_pool_allocator is None:
            raise ValueError(
                "token_to_kv_pool_allocator is required to run LoraRadixCache"
            )

        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = disable
        self.device = self.token_to_kv_pool_allocator.device

        self.key_match_fn = _key_match
        self.get_child_key_fn = get_child_key
        self.reset()

    def reset(self):
        """Drop all cached entries and start from an empty root."""
        self.root_node = LoRATreeNode()
        self.root_node.key = LoRAKey(lora_id="", token_ids=[])
        # The root holds no value; size helpers must tolerate None here.
        self.root_node.value = None
        self.evictable_size_ = 0
        self.protected_size_ = 0

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        raise ValueError(
            "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
        )

    def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
        """Find the longest cached prefix of ``key`` in the LoRA radix tree.

        Args:
            key: the adapter id and token ids to look up.

        Returns:
            A ``MatchResult``. ``device_indices`` concatenates the values along
            the matched path (an empty int64 tensor on a miss).
            ``last_device_node`` and ``last_host_node`` are both the last
            matched node (the root on a miss).

        Note that this API can modify the tree: when the match ends inside a
        node's key, that node is split so the matched part becomes its own node.
        """
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                last_device_node=self.root_node,
                last_host_node=self.root_node,
            )

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def insert(self, key: LoRAKey, value=None):
        """Insert ``key`` -> ``value``; return how many leading tokens were already cached.

        When ``value`` is omitted, a copy of the token ids is stored. Returns 0
        when the cache is disabled.
        """
        if self.disable:
            return 0

        if value is None:
            value = list(key.token_ids)
        return self._insert_helper(self.root_node, key, value)

    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        if self.disable:
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

        # Radix Cache takes one ref in memory pool
        lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
        new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
        # Slots for tokens that were already in the tree are duplicates; free them.
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # Remove req slot release the cache lock
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: Req):
        """Cache request when it is unfinished."""
        if self.disable:
            return

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
        inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
        new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # The prefix indices could be updated, reuse it. Read the MatchResult by
        # field name rather than unpacking, so this does not depend on how many
        # fields MatchResult has.
        match = self.match_prefix_with_lora_id(inserted_key)
        new_indices = match.device_indices
        new_last_node = match.last_device_node
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
        )

        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        """Print the tree structure and the total number of cached tokens."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the number of tokens held by non-evicted nodes."""
        return self._total_size_helper()

    def evict(self, num_tokens: int):
        """Evict unlocked leaves, least recently accessed first.

        Eviction stops once at least ``num_tokens`` slots have been freed or
        nothing evictable remains. Whole nodes are evicted, so more than
        ``num_tokens`` slots may be freed.
        """
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x == self.root_node:
                break
            if x.lock_ref > 0:
                continue

            self.token_to_kv_pool_allocator.free(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            # A parent that just lost its last child becomes an eviction candidate.
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: LoRATreeNode):
        """Lock ``node`` and its ancestors (the root excluded).

        Returns the change in evictable size (zero or negative).
        """
        if self.disable:
            return 0

        delta = 0
        while node != self.root_node:
            if node.lock_ref == 0:
                self.evictable_size_ -= len(node.value)
                self.protected_size_ += len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: LoRATreeNode):
        """Unlock ``node`` and its ancestors (the root excluded).

        Returns the change in evictable size (zero or positive).
        """
        if self.disable:
            return 0

        delta = 0
        while node != self.root_node:
            if node.lock_ref == 1:
                self.evictable_size_ += len(node.value)
                self.protected_size_ -= len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        return self.evictable_size_

    def protected_size(self):
        # protected size refers to the size of the cache that is locked
        return self.protected_size_

    def all_values_flatten(self):
        """Concatenate every node's value in depth-first pre-order."""
        values = []

        def _dfs_helper(node: LoRATreeNode):
            for _, child in node.children.items():
                values.append(child.value)
                _dfs_helper(child)

        _dfs_helper(self.root_node)
        if not values:
            # torch.cat rejects an empty list; an empty tree yields no indices.
            return torch.empty((0,), dtype=torch.int64, device=self.device)
        return torch.cat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
        node.last_access_time = time.monotonic()

        child_key = self.get_child_key_fn(key)

        value = []
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                # Match ends inside this child: split it so the matched part
                # becomes its own node and stop there.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                node = new_node
                break
            else:
                value.append(child.value)
                node = child
                key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
        # new_node -> child
        new_node = LoRATreeNode()
        key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
        key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
        new_node.children = {self.get_child_key_fn(key_split_2): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = key_split_1
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = key_split_2
        child.value = child.value[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        return new_node

    def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
            value = value[prefix_len:]

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            # The unmatched remainder becomes a new leaf.
            new_node = LoRATreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)
        return total_prefix_length

    def _print_helper(self, node: LoRATreeNode, indent: int):
        """Prints the radix tree in a human-readable format."""
        stack = [(node, indent)]
        while stack:
            current_node, current_indent = stack.pop()
            print(
                " " * current_indent,
                len(current_node.key),
                current_node.key.token_ids[:10],
                f"r={current_node.lock_ref}",
            )
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _delete_leaf(self, node):
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self):
        total_size = 0
        stack = [self.root_node]
        while stack:
            current_node = stack.pop()
            # The root keeps value=None (set in reset); calling len() on it
            # would raise TypeError, so nodes without a value contribute 0.
            if current_node.value is not None:
                total_size += len(current_node.value)
            for child in current_node.children.values():
                if child.evicted:
                    continue
                stack.append(child)
        return total_size

    def _collect_leaves(self):
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            else:
                stack.extend(cur_node.children.values())

        return ret_list
|