sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool,
+            self.worker.model_runner.token_to_kv_pool_allocator,
         )

     def forward_thread_func(self):
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
                 logits_output.next_token_logprobs.tolist()
             )
             if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = (
+                logits_output.input_token_logprobs = tuple(
                     logits_output.input_token_logprobs.tolist()
                 )
             next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
             sampling_info,
             sampling_info_done=threading.Event(),
-            scaling_penalties=sampling_info.scaling_penalties,
-            linear_penalties=sampling_info.linear_penalties,
+            penalizer_orchestrator=None,
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
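The last hunk above replaces the per-field penalty-tensor copies with a single `penalizer_orchestrator=None`, which goes hand in hand with the penaltylib rewrite listed at the top. Below is a minimal sketch of the `dataclasses.replace` idiom the overlap worker relies on; `SamplingInfo` here is a hypothetical stand-in, not sglang's actual `SamplingBatchInfo`:

```python
import dataclasses
import threading
from typing import Any, Optional


@dataclasses.dataclass
class SamplingInfo:
    temperature: float
    sampling_info_done: Optional[threading.Event] = None
    penalizer_orchestrator: Optional[Any] = None  # mutable, scheduler-owned state


# Original info is owned by the scheduler thread; the orchestrator holds
# mutable state that must not be touched concurrently by the forward thread.
info = SamplingInfo(temperature=0.7, penalizer_orchestrator=object())

# Shallow copy handed to the forward thread: a fresh Event for signaling,
# and the orchestrator dropped so the copy carries no shared mutable state.
copy_for_worker = dataclasses.replace(
    info,
    sampling_info_done=threading.Event(),
    penalizer_orchestrator=None,
)

assert copy_for_worker.penalizer_orchestrator is None
assert info.penalizer_orchestrator is not None  # original left untouched
```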
sglang/srt/mem_cache/chunk_cache.py

@@ -1,29 +1,33 @@
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

-from typing import TYPE_CHECKING, Callable, List, Tuple
+import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req


 class ChunkCacheEntry:
-    def __init__(self, rid, value):
+    def __init__(self, rid: str, value: torch.Tensor):
         self.rid = rid
         self.value = value


 class ChunkCache(BasePrefixCache):
     def __init__(
-        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.entries: Dict[str, ChunkCacheEntry] = {}

         self.reset()

@@ -48,16 +52,13 @@ class ChunkCache(BasePrefixCache):
             req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
-        self.token_to_kv_pool.free(kv_indices)
+        self.token_to_kv_pool_allocator.free(kv_indices)

         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.fill_ids)
-        else:
-            token_id_len = len(token_ids)
+    def cache_unfinished_req(self, req: Req):
+        token_id_len = len(req.fill_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_id_len
@@ -86,5 +87,11 @@ class ChunkCache(BasePrefixCache):
     def evictable_size(self):
         return 0

+    def pretty_print(self):
+        return ""
+
     def protected_size(self):
         return 0
+
+    def pretty_print(self):
+        return ""
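The refactor threads the renamed `TokenToKVPoolAllocator` through `ChunkCache` and adds an `entries` dict keyed by request id. The self-contained toy below mirrors that bookkeeping with plain lists standing in for the real pools; all `Toy*` names and the simplified method signatures are ours, not sglang's. On request completion, both the request-to-token mapping and the KV slots are released:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyReqToTokenPool:
    # request slot -> list of KV slot indices, in the spirit of ReqToTokenPool
    req_to_token: Dict[int, List[int]] = field(default_factory=dict)

    def free(self, req_pool_idx: int) -> None:
        self.req_to_token.pop(req_pool_idx, None)


@dataclass
class ToyKVAllocator:
    # free-list allocator, in the spirit of TokenToKVPoolAllocator
    freed: List[int] = field(default_factory=list)

    def free(self, kv_indices: List[int]) -> None:
        self.freed.extend(kv_indices)


class ToyChunkCache:
    """No prefix sharing: on request completion, release both mappings."""

    def __init__(self, req_to_token_pool, token_to_kv_pool_allocator):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.entries = {}

    def cache_finished_req(self, rid: str, req_pool_idx: int, fill_len: int):
        kv_indices = self.req_to_token_pool.req_to_token[req_pool_idx][:fill_len]
        self.req_to_token_pool.free(req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_indices)
        self.entries.pop(rid, None)


pool, alloc = ToyReqToTokenPool(), ToyKVAllocator()
pool.req_to_token[0] = [10, 11, 12, 13]
cache = ToyChunkCache(pool, alloc)
cache.cache_finished_req("req-0", req_pool_idx=0, fill_len=4)
assert alloc.freed == [10, 11, 12, 13]
```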
sglang/srt/mem_cache/hiradix_cache.py (new file)

@@ -0,0 +1,394 @@
+import heapq
+import logging
+import time
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MHATokenToKVPoolHost,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
+
+logger = logging.getLogger(__name__)
+
+
+class HiRadixCache(RadixCache):
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: MHATokenToKVPool,
+    ):
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool, self.token_to_kv_pool_host
+        )
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through = {}
+        # record the node segments with ongoing load back
+        self.ongoing_load_back = {}
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = 1
+        self.load_back_threshold = 10
+        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
+
+    def reset(self):
+        TreeNode.counter = 0
+        self.cache_controller.reset()
+        self.token_to_kv_pool_host.clear()
+        super().reset()
+
+    def get_height(self, node: TreeNode):
+        height = 0
+        while node != self.root_node:
+            node = node.parent
+            height += 1
+        return height
+
+    def write_backup(self, node: TreeNode):
+        host_indices = self.cache_controller.write(
+            device_indices=node.value,
+            priority=-self.get_height(node),
+            node_id=node.id,
+        )
+        if host_indices is None:
+            self.evict_host(len(node.value))
+            host_indices = self.cache_controller.write(
+                device_indices=node.value,
+                priority=-self.get_height(node),
+                node_id=node.id,
+            )
+        if host_indices is not None:
+            node.host_value = host_indices
+            self.ongoing_write_through[node.id] = node
+            self.inc_lock_ref(node)
+        else:
+            return None
+
+        return len(host_indices)
+
+    def inc_hit_count(self, node: TreeNode):
+        if self.cache_controller.write_policy != "write_through_selective":
+            return
+        node.hit_count += 1
+        if node.host_value is None and node.hit_count > self.write_through_threshold:
+            self.write_backup(node)
+            node.hit_count = 0
+
+    def writing_check(self):
+        while not self.cache_controller.ack_write_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_write_queue.get_nowait()
+                self.dec_lock_ref(self.ongoing_write_through[ack_id])
+                # clear the reference
+                del self.ongoing_write_through[ack_id]
+            except Exception:
+                break
+
+    def loading_check(self):
+        while not self.cache_controller.ack_load_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_load_queue.get_nowait()
+                start_node, end_node = self.ongoing_load_back[ack_id]
+                self.dec_lock_ref(end_node)
+                while end_node != start_node:
+                    assert end_node.loading
+                    end_node.loading = False
+                    end_node = end_node.parent
+                # clear the reference
+                del self.ongoing_load_back[ack_id]
+            except Exception:
+                break
+
+    def evictable_size(self):
+        self.writing_check()
+        self.loading_check()
+        return self.evictable_size_
+
+    def evict(self, num_tokens: int, evict_callback=None):
+        leaves = self._collect_leaves_device()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        pending_nodes = []
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+
+            if x.lock_ref > 0:
+                continue
+
+            if x.host_value is None:
+                if self.cache_controller.write_policy == "write_back":
+                    num_evicted += self.write_backup(x)
+                elif self.cache_controller.write_policy == "write_through_selective":
+                    num_evicted += self._evict_write_through_selective(x)
+                else:
+                    assert (
+                        self.cache_controller.write_policy != "write_through"
+                    ), "write_through should be inclusive"
+                    raise NotImplementedError
+            else:
+                num_evicted += self._evict_write_through(x)
+
+            for child in x.parent.children.values():
+                if child in pending_nodes:
+                    continue
+                if not child.evicted:
+                    break
+            else:
+                # all children are evicted or no children
+                heapq.heappush(leaves, x.parent)
+
+        if self.cache_controller.write_policy == "write_back":
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                self.writing_check()
+                time.sleep(0.1)
+
+    def _evict_write_through(self, node: TreeNode):
+        # evict a node already written to host
+        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        assert num_evicted > 0
+        self.evictable_size_ -= num_evicted
+        node.value = None
+        return num_evicted
+
+    def _evict_write_through_selective(self, node: TreeNode):
+        # evict a node not initiated write to host
+        self.cache_controller.mem_pool_device.free(node.value)
+        num_evicted = len(node.value)
+        self._delete_leaf(node)
+        return num_evicted
+
+    def evict_host(self, num_tokens: int):
+        leaves = self._collect_leaves()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+            if x == self.root_node:
+                break
+            # only evict the host value of evicted nodes
+            if not x.evicted:
+                continue
+            assert x.lock_ref == 0 and x.host_value is not None
+
+            assert self.cache_controller.evict_host(x.host_value) > 0
+            for k, v in x.parent.children.items():
+                if v == x:
+                    break
+            del x.parent.children[k]
+
+            if len(x.parent.children) == 0 and x.parent.evicted:
+                heapq.heappush(leaves, x.parent)
+
+    def load_back(
+        self, node: TreeNode, mem_quota: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        # todo: more loading policies
+
+        last_hit_node = node
+        nodes_to_load = []
+        while node.evicted:
+            assert (
+                node.backuped
+            ), "No backup available on evicted nodes, should not happen"
+            nodes_to_load.insert(0, node)
+            node = node.parent
+        else:
+            ancester_node = node
+
+        # protect the ancestor nodes from eviction
+        delta = self.inc_lock_ref(ancester_node)
+
+        # load it all or not at all
+        host_indices = torch.cat([n.host_value for n in nodes_to_load])
+        if len(host_indices) < self.load_back_threshold or (
+            len(host_indices) > mem_quota + delta if mem_quota is not None else False
+        ):
+            # skip loading back if the total size is too small or exceeding the memory quota
+            self.dec_lock_ref(ancester_node)
+            return None
+
+        device_indices = self.cache_controller.load(
+            host_indices=host_indices, node_id=last_hit_node.id
+        )
+        if device_indices is None:
+            self.evict(len(host_indices))
+            device_indices = self.cache_controller.load(
+                host_indices=host_indices, node_id=last_hit_node.id
+            )
+        self.dec_lock_ref(ancester_node)
+        if device_indices is None:
+            # no sufficient GPU memory to load back KV caches
+            return None
+
+        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        offset = 0
+        for node in nodes_to_load:
+            node.value = device_indices[offset : offset + len(node.host_value)]
+            offset += len(node.host_value)
+            node.loading = True
+        self.evictable_size_ += len(device_indices)
+        self.inc_lock_ref(last_hit_node)
+
+        return device_indices
+
+    def loading_complete(self, node: TreeNode):
+        self.loading_check()
+        return node.loading == False
+
+    def init_load_back(
+        self,
+        last_node: TreeNode,
+        prefix_indices: torch.Tensor,
+        mem_quota: Optional[int] = None,
+    ):
+        assert (
+            len(prefix_indices) == 0 or prefix_indices.is_cuda
+        ), "indices of device kV caches should be on GPU"
+        if last_node.evicted:
+            loading_values = self.load_back(last_node, mem_quota)
+            if loading_values is not None:
+                prefix_indices = (
+                    loading_values
+                    if len(prefix_indices) == 0
+                    else torch.cat([prefix_indices, loading_values])
+                )
+                logger.debug(
+                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
+                )
+
+            while last_node.evicted:
+                last_node = last_node.parent
+
+        return last_node, prefix_indices
+
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
+                if not new_node.evicted:
+                    value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                self.inc_hit_count(child)
+                if not child.evicted:
+                    value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len: int):
+        # child node split into new_node -> child
+        new_node = TreeNode()
+        new_node.children = {key[split_len]: child}
+        new_node.parent = child.parent
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
+        new_node.loading = child.loading
+
+        # split value and host value if exists
+        if child.evicted:
+            new_node.value = None
+        else:
+            new_node.value = child.value[:split_len]
+            child.value = child.value[split_len:]
+        if child.host_value is not None:
+            new_node.host_value = child.host_value[:split_len]
+            child.host_value = child.host_value[split_len:]
+        child.parent = new_node
+        child.key = child.key[split_len:]
+        new_node.parent.children[key[0]] = new_node
+        return new_node
+
+    def _insert_helper(self, node: TreeNode, key: List, value):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+
+            if prefix_len == len(child.key):
+                if child.evicted:
+                    # change the reference if the node is evicted
+                    # this often happens in the case of KV cache recomputation
+                    child.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(child.host_value)
+                    self.evictable_size_ += len(value[:prefix_len])
+                    return self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+                else:
+                    self.inc_hit_count(child)
+                    return prefix_len + self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+
+            # partial match, split the node
+            new_node = self._split_node(child.key, child, prefix_len)
+            if new_node.evicted:
+                new_node.value = value[:prefix_len]
+                self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                self.evictable_size_ += len(new_node.value)
+                return self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+            else:
+                self.inc_hit_count(new_node)
+                return prefix_len + self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = value
+            node.children[key[0]] = new_node
+            self.evictable_size_ += len(value)
+
+            if self.cache_controller.write_policy == "write_through":
+                self.write_backup(new_node)
+        return 0
+
+    def _collect_leaves_device(self):
+        def is_leaf(node):
+            if node.evicted:
+                return False
+            if node == self.root_node:
+                return False
+            if len(node.children) == 0:
+                return True
+            for child in node.children.values():
+                if not child.evicted:
+                    return False
+            return True
+
+        ret_list = []
+        stack = [self.root_node]
+        while stack:
+            cur_node = stack.pop()
+            if is_leaf(cur_node):
+                ret_list.append(cur_node)
+            else:
+                for cur_child in cur_node.children.values():
+                    if not cur_child.evicted:
+                        stack.append(cur_child)
+        return ret_list
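`HiRadixCache` adds a host-memory tier behind the device radix cache, with three write policies visible in the code: `write_through` (back up every inserted node), `write_through_selective` (back up only after repeated hits), and `write_back` (back up lazily at eviction time). The toy below models just the selective policy's hit-count gate from `inc_hit_count`/`write_backup`; the `Toy*` classes are ours, with plain lists standing in for KV index tensors and the host pool:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    value: List[int]                      # "device" KV slot indices
    hit_count: int = 0
    host_value: Optional[List[int]] = None  # set once backed up to host


class ToySelectiveWriteThrough:
    def __init__(self, write_through_threshold: int = 1):
        self.write_through_threshold = write_through_threshold
        self.host_store: List[List[int]] = []

    def write_backup(self, node: ToyNode) -> int:
        # stand-in for cache_controller.write(): copy device segment to host
        node.host_value = list(node.value)
        self.host_store.append(node.host_value)
        return len(node.host_value)

    def inc_hit_count(self, node: ToyNode) -> None:
        # only nodes hit more than `threshold` times earn a host backup
        node.hit_count += 1
        if node.host_value is None and node.hit_count > self.write_through_threshold:
            self.write_backup(node)
            node.hit_count = 0


cache = ToySelectiveWriteThrough(write_through_threshold=1)
node = ToyNode(value=[1, 2, 3])
cache.inc_hit_count(node)            # first hit: below threshold, no backup
assert node.host_value is None
cache.inc_hit_count(node)            # second hit: exceeds threshold, backed up
assert node.host_value == [1, 2, 3]  # now evictable from device without data loss
```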
sglang/srt/mem_cache/memory_pool.py

@@ -20,9 +20,12 @@ Memory pool.

 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-BaseTokenToKVPool maps a token location to its kv cache data.
+TokenToKVPoolAllocator maps a token location to its KV cache data.
+KVCache actually holds the physical kv cache. Allocation indices are allocated
+by TokenToKVPoolAllocator
 """

+import abc
 import logging
 import threading
 from enum import IntEnum
@@ -89,7 +92,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class TokenToKVPoolAllocator:
     """A memory pool that maps a token location to its kv cache data."""

     def __init__(
@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
         self.device = device

         self.free_slots = None
@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
         self.is_in_free_group = False
         self.free_group = []

+
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

+    @abc.abstractmethod
     def set_kv_buffer(
         self,
         layer: RadixAttention,
@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
         raise NotImplementedError()


-class MHATokenToKVPool(BaseTokenToKVPool):
+class MHATokenToKVPool(KVCache):

     def __init__(
         self,
@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -192,7 +203,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):

         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
-            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )

     def _create_buffers(self):
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


-class MLATokenToKVPool(BaseTokenToKVPool):
+class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank

         memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         self.kv_buffer[layer_id][loc] = cache_k


-class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         heavy_channel_num: int,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -437,7 +460,7 @@ def synchronized(func):
     return wrapper


-class MLATokenToKVPoolHost:
+class MHATokenToKVPoolHost:

     def __init__(
         self,
@@ -502,6 +525,9 @@ class MLATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
     @debug_timing
     def transfer(self, indices, flat_data):
         # backup prepared data from device to host