sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
            self.worker.model_runner.req_to_token_pool,
-           self.worker.model_runner.
+           self.worker.model_runner.token_to_kv_pool_allocator,
        )

    def forward_thread_func(self):
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
                logits_output.next_token_logprobs.tolist()
            )
            if logits_output.input_token_logprobs is not None:
-               logits_output.input_token_logprobs = (
+               logits_output.input_token_logprobs = tuple(
                    logits_output.input_token_logprobs.tolist()
                )
            next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
            sampling_info,
            sampling_info_done=threading.Event(),
-
-           linear_penalties=sampling_info.linear_penalties,
+           penalizer_orchestrator=None,
        )

        # A cuda stream sync here to avoid the cuda illegal memory access error.
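The last hunk above changes how the overlap worker clones per-batch sampling state: rather than copying individual penalty tensors, the copy now drops the `penalizer_orchestrator` reference. A minimal sketch of that `dataclasses.replace` pattern, using a hypothetical simplified `SamplingInfo` dataclass (a stand-in, not the real `SamplingBatchInfo`):

```python
import dataclasses
import threading
from typing import Any, Optional


@dataclasses.dataclass
class SamplingInfo:
    # Hypothetical stand-in for sglang's SamplingBatchInfo.
    temperatures: Any = None
    penalizer_orchestrator: Optional[Any] = None
    sampling_info_done: Optional[threading.Event] = None


def clone_for_overlap_thread(sampling_info: SamplingInfo) -> SamplingInfo:
    # Shallow-copy the batch's sampling state, give the copy its own
    # completion event, and detach the penalizer state so the copy
    # does not share it with the original batch.
    return dataclasses.replace(
        sampling_info,
        sampling_info_done=threading.Event(),
        penalizer_orchestrator=None,
    )


copy = clone_for_overlap_thread(SamplingInfo(temperatures=[1.0]))
assert copy.penalizer_orchestrator is None and not copy.sampling_info_done.is_set()
```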
sglang/srt/mem_cache/chunk_cache.py

@@ -1,29 +1,33 @@
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

-
+import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req


 class ChunkCacheEntry:
-    def __init__(self, rid, value):
+    def __init__(self, rid: str, value: torch.Tensor):
         self.rid = rid
         self.value = value


 class ChunkCache(BasePrefixCache):
     def __init__(
-        self,
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
-        self.
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.entries: Dict[str, ChunkCacheEntry] = {}

         self.reset()

@@ -48,16 +52,13 @@ class ChunkCache(BasePrefixCache):
             req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
-        self.
+        self.token_to_kv_pool_allocator.free(kv_indices)

         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: Req
-
-            token_id_len = len(req.fill_ids)
-        else:
-            token_id_len = len(token_ids)
+    def cache_unfinished_req(self, req: Req):
+        token_id_len = len(req.fill_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_id_len
@@ -86,5 +87,11 @@ class ChunkCache(BasePrefixCache):
     def evictable_size(self):
         return 0

+    def pretty_print(self):
+        return ""
+
     def protected_size(self):
         return 0
+
+    def pretty_print(self):
+        return ""
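`ChunkCache` now receives both pools explicitly and frees KV slots through the new `TokenToKVPoolAllocator`. The lookup it performs is a plain row slice of the request-to-token table; a self-contained sketch of that pattern with toy sizes (illustrative only, not the sglang classes):

```python
import torch

# Toy request-to-token table: one row per request slot, one column per
# token position, holding indices into the KV cache.
max_requests, max_tokens = 4, 16
req_to_token = torch.zeros(max_requests, max_tokens, dtype=torch.int64)

# Pretend request slot 2 has 5 tokens stored at KV indices 100..104.
req_pool_idx, token_id_len = 2, 5
req_to_token[req_pool_idx, :token_id_len] = torch.arange(100, 105)

# What cache_finished_req / cache_unfinished_req do above: slice out the
# request's KV indices, then hand them back to the allocator's free list.
kv_indices = req_to_token[req_pool_idx, :token_id_len]
free_slots = []
free_slots.extend(kv_indices.tolist())  # stand-in for token_to_kv_pool_allocator.free(kv_indices)
print(free_slots)                        # [100, 101, 102, 103, 104]
```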
sglang/srt/mem_cache/hiradix_cache.py (new file)

@@ -0,0 +1,394 @@
+import heapq
+import logging
+import time
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MHATokenToKVPoolHost,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
+
+logger = logging.getLogger(__name__)
+
+
+class HiRadixCache(RadixCache):
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: MHATokenToKVPool,
+    ):
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool, self.token_to_kv_pool_host
+        )
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through = {}
+        # record the node segments with ongoing load back
+        self.ongoing_load_back = {}
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = 1
+        self.load_back_threshold = 10
+        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
+
+    def reset(self):
+        TreeNode.counter = 0
+        self.cache_controller.reset()
+        self.token_to_kv_pool_host.clear()
+        super().reset()
+
+    def get_height(self, node: TreeNode):
+        height = 0
+        while node != self.root_node:
+            node = node.parent
+            height += 1
+        return height
+
+    def write_backup(self, node: TreeNode):
+        host_indices = self.cache_controller.write(
+            device_indices=node.value,
+            priority=-self.get_height(node),
+            node_id=node.id,
+        )
+        if host_indices is None:
+            self.evict_host(len(node.value))
+            host_indices = self.cache_controller.write(
+                device_indices=node.value,
+                priority=-self.get_height(node),
+                node_id=node.id,
+            )
+        if host_indices is not None:
+            node.host_value = host_indices
+            self.ongoing_write_through[node.id] = node
+            self.inc_lock_ref(node)
+        else:
+            return None
+
+        return len(host_indices)
+
+    def inc_hit_count(self, node: TreeNode):
+        if self.cache_controller.write_policy != "write_through_selective":
+            return
+        node.hit_count += 1
+        if node.host_value is None and node.hit_count > self.write_through_threshold:
+            self.write_backup(node)
+            node.hit_count = 0
+
+    def writing_check(self):
+        while not self.cache_controller.ack_write_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_write_queue.get_nowait()
+                self.dec_lock_ref(self.ongoing_write_through[ack_id])
+                # clear the reference
+                del self.ongoing_write_through[ack_id]
+            except Exception:
+                break
+
+    def loading_check(self):
+        while not self.cache_controller.ack_load_queue.empty():
+            try:
+                ack_id = self.cache_controller.ack_load_queue.get_nowait()
+                start_node, end_node = self.ongoing_load_back[ack_id]
+                self.dec_lock_ref(end_node)
+                while end_node != start_node:
+                    assert end_node.loading
+                    end_node.loading = False
+                    end_node = end_node.parent
+                # clear the reference
+                del self.ongoing_load_back[ack_id]
+            except Exception:
+                break
+
+    def evictable_size(self):
+        self.writing_check()
+        self.loading_check()
+        return self.evictable_size_
+
+    def evict(self, num_tokens: int, evict_callback=None):
+        leaves = self._collect_leaves_device()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        pending_nodes = []
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+
+            if x.lock_ref > 0:
+                continue
+
+            if x.host_value is None:
+                if self.cache_controller.write_policy == "write_back":
+                    num_evicted += self.write_backup(x)
+                elif self.cache_controller.write_policy == "write_through_selective":
+                    num_evicted += self._evict_write_through_selective(x)
+                else:
+                    assert (
+                        self.cache_controller.write_policy != "write_through"
+                    ), "write_through should be inclusive"
+                    raise NotImplementedError
+            else:
+                num_evicted += self._evict_write_through(x)
+
+            for child in x.parent.children.values():
+                if child in pending_nodes:
+                    continue
+                if not child.evicted:
+                    break
+            else:
+                # all children are evicted or no children
+                heapq.heappush(leaves, x.parent)
+
+        if self.cache_controller.write_policy == "write_back":
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                self.writing_check()
+                time.sleep(0.1)
+
+    def _evict_write_through(self, node: TreeNode):
+        # evict a node already written to host
+        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        assert num_evicted > 0
+        self.evictable_size_ -= num_evicted
+        node.value = None
+        return num_evicted
+
+    def _evict_write_through_selective(self, node: TreeNode):
+        # evict a node not initiated write to host
+        self.cache_controller.mem_pool_device.free(node.value)
+        num_evicted = len(node.value)
+        self._delete_leaf(node)
+        return num_evicted
+
+    def evict_host(self, num_tokens: int):
+        leaves = self._collect_leaves()
+        heapq.heapify(leaves)
+
+        num_evicted = 0
+        while num_evicted < num_tokens and len(leaves):
+            x = heapq.heappop(leaves)
+            if x == self.root_node:
+                break
+            # only evict the host value of evicted nodes
+            if not x.evicted:
+                continue
+            assert x.lock_ref == 0 and x.host_value is not None
+
+            assert self.cache_controller.evict_host(x.host_value) > 0
+            for k, v in x.parent.children.items():
+                if v == x:
+                    break
+            del x.parent.children[k]
+
+            if len(x.parent.children) == 0 and x.parent.evicted:
+                heapq.heappush(leaves, x.parent)
+
+    def load_back(
+        self, node: TreeNode, mem_quota: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        # todo: more loading policies
+
+        last_hit_node = node
+        nodes_to_load = []
+        while node.evicted:
+            assert (
+                node.backuped
+            ), "No backup available on evicted nodes, should not happen"
+            nodes_to_load.insert(0, node)
+            node = node.parent
+        else:
+            ancester_node = node
+
+        # protect the ancestor nodes from eviction
+        delta = self.inc_lock_ref(ancester_node)
+
+        # load it all or not at all
+        host_indices = torch.cat([n.host_value for n in nodes_to_load])
+        if len(host_indices) < self.load_back_threshold or (
+            len(host_indices) > mem_quota + delta if mem_quota is not None else False
+        ):
+            # skip loading back if the total size is too small or exceeding the memory quota
+            self.dec_lock_ref(ancester_node)
+            return None
+
+        device_indices = self.cache_controller.load(
+            host_indices=host_indices, node_id=last_hit_node.id
+        )
+        if device_indices is None:
+            self.evict(len(host_indices))
+            device_indices = self.cache_controller.load(
+                host_indices=host_indices, node_id=last_hit_node.id
+            )
+        self.dec_lock_ref(ancester_node)
+        if device_indices is None:
+            # no sufficient GPU memory to load back KV caches
+            return None
+
+        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        offset = 0
+        for node in nodes_to_load:
+            node.value = device_indices[offset : offset + len(node.host_value)]
+            offset += len(node.host_value)
+            node.loading = True
+        self.evictable_size_ += len(device_indices)
+        self.inc_lock_ref(last_hit_node)
+
+        return device_indices
+
+    def loading_complete(self, node: TreeNode):
+        self.loading_check()
+        return node.loading == False
+
+    def init_load_back(
+        self,
+        last_node: TreeNode,
+        prefix_indices: torch.Tensor,
+        mem_quota: Optional[int] = None,
+    ):
+        assert (
+            len(prefix_indices) == 0 or prefix_indices.is_cuda
+        ), "indices of device kV caches should be on GPU"
+        if last_node.evicted:
+            loading_values = self.load_back(last_node, mem_quota)
+            if loading_values is not None:
+                prefix_indices = (
+                    loading_values
+                    if len(prefix_indices) == 0
+                    else torch.cat([prefix_indices, loading_values])
+                )
+                logger.debug(
+                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
+                )
+
+            while last_node.evicted:
+                last_node = last_node.parent
+
+        return last_node, prefix_indices
+
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
+                if not new_node.evicted:
+                    value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                self.inc_hit_count(child)
+                if not child.evicted:
+                    value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len: int):
+        # child node split into new_node -> child
+        new_node = TreeNode()
+        new_node.children = {key[split_len]: child}
+        new_node.parent = child.parent
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
+        new_node.loading = child.loading
+
+        # split value and host value if exists
+        if child.evicted:
+            new_node.value = None
+        else:
+            new_node.value = child.value[:split_len]
+            child.value = child.value[split_len:]
+        if child.host_value is not None:
+            new_node.host_value = child.host_value[:split_len]
+            child.host_value = child.host_value[split_len:]
+        child.parent = new_node
+        child.key = child.key[split_len:]
+        new_node.parent.children[key[0]] = new_node
+        return new_node
+
+    def _insert_helper(self, node: TreeNode, key: List, value):
+        node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+
+            if prefix_len == len(child.key):
+                if child.evicted:
+                    # change the reference if the node is evicted
+                    # this often happens in the case of KV cache recomputation
+                    child.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(child.host_value)
+                    self.evictable_size_ += len(value[:prefix_len])
+                    return self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+                else:
+                    self.inc_hit_count(child)
+                    return prefix_len + self._insert_helper(
+                        child, key[prefix_len:], value[prefix_len:]
+                    )
+
+            # partial match, split the node
+            new_node = self._split_node(child.key, child, prefix_len)
+            if new_node.evicted:
+                new_node.value = value[:prefix_len]
+                self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                self.evictable_size_ += len(new_node.value)
+                return self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+            else:
+                self.inc_hit_count(new_node)
+                return prefix_len + self._insert_helper(
+                    new_node, key[prefix_len:], value[prefix_len:]
+                )
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = value
+            node.children[key[0]] = new_node
+            self.evictable_size_ += len(value)
+
+            if self.cache_controller.write_policy == "write_through":
+                self.write_backup(new_node)
+        return 0
+
+    def _collect_leaves_device(self):
+        def is_leaf(node):
+            if node.evicted:
+                return False
+            if node == self.root_node:
+                return False
+            if len(node.children) == 0:
+                return True
+            for child in node.children.values():
+                if not child.evicted:
+                    return False
+            return True
+
+        ret_list = []
+        stack = [self.root_node]
+        while stack:
+            cur_node = stack.pop()
+            if is_leaf(cur_node):
+                ret_list.append(cur_node)
+            else:
+                for cur_child in cur_node.children.values():
+                    if not cur_child.evicted:
+                        stack.append(cur_child)
+        return ret_list
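The core of `HiRadixCache` is an ordinary radix-tree prefix match plus a host-memory backup of each node's KV indices. A standalone toy version of the node split that `_split_node` performs (plain Python, not the sglang `TreeNode`), showing how a partial prefix match cuts one node into a parent holding the shared prefix and a child holding the remainder:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToyNode:
    key: List[int] = field(default_factory=list)     # token ids on this edge
    value: List[int] = field(default_factory=list)   # KV indices for those tokens
    children: Dict[int, "ToyNode"] = field(default_factory=dict)
    parent: Optional["ToyNode"] = None


def split_node(child: ToyNode, split_len: int) -> ToyNode:
    """Split `child` into new_node -> child at `split_len`, like _split_node."""
    new_node = ToyNode(key=child.key[:split_len], value=child.value[:split_len])
    new_node.parent = child.parent
    new_node.children = {child.key[split_len]: child}
    if new_node.parent is not None:
        new_node.parent.children[new_node.key[0]] = new_node
    child.parent = new_node
    child.key = child.key[split_len:]
    child.value = child.value[split_len:]
    return new_node


root = ToyNode()
node = ToyNode(key=[7, 8, 9, 10], value=[0, 1, 2, 3], parent=root)
root.children[7] = node

# A new request shares only the prefix [7, 8]; split after 2 tokens.
mid = split_node(node, 2)
assert mid.key == [7, 8] and mid.children[9].key == [9, 10]
```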
sglang/srt/mem_cache/memory_pool.py

@@ -20,9 +20,11 @@ Memory pool.

 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """

+import abc
 import logging
 import threading
 from enum import IntEnum
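The new layering separates index bookkeeping from storage: a `KVCache` subclass owns the physical buffers, while `TokenToKVPoolAllocator` only hands out and reclaims token slots and keeps a reference to the cache it serves. A simplified standalone sketch of that split (toy shapes and class names, not the real sglang classes):

```python
from typing import List, Optional, Tuple

import torch


class ToyKVCache:
    """Holds the physical K/V buffers, one pair per layer (toy version of KVCache)."""

    def __init__(self, num_layers: int, size: int, head_dim: int):
        self.k = [torch.zeros(size, head_dim) for _ in range(num_layers)]
        self.v = [torch.zeros(size, head_dim) for _ in range(num_layers)]

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.k[layer_id], self.v[layer_id]

    def set_kv_buffer(self, layer_id: int, loc: torch.Tensor,
                      cache_k: torch.Tensor, cache_v: torch.Tensor) -> None:
        self.k[layer_id][loc] = cache_k
        self.v[layer_id][loc] = cache_v


class ToyAllocator:
    """Hands out token slots; knows nothing about the buffer layout (toy TokenToKVPoolAllocator)."""

    def __init__(self, size: int, kvcache: ToyKVCache):
        self.free_slots: List[int] = list(range(size))
        self._kvcache = kvcache

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        if need_size > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return torch.tensor(out)

    def free(self, indices: torch.Tensor) -> None:
        self.free_slots.extend(indices.tolist())

    def get_kvcache(self) -> ToyKVCache:
        return self._kvcache


cache = ToyKVCache(num_layers=2, size=32, head_dim=8)
allocator = ToyAllocator(size=32, kvcache=cache)
loc = allocator.alloc(4)                               # token slots for a 4-token request
cache.set_kv_buffer(0, loc, torch.randn(4, 8), torch.randn(4, 8))
allocator.free(loc)                                    # slots go back once the request is done
```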
@@ -89,22 +91,43 @@ class ReqToTokenPool:
        self.free_slots = list(range(self.size))


-class
-
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
+class TokenToKVPoolAllocator:
+    """An allocator managing the indices to kv cache data."""

     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
         self.device = device

         self.free_slots = None
@@ -112,9 +135,14 @@ class BaseTokenToKVPool:
         self.free_group = []
         self.clear()

+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)

+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
@@ -148,26 +176,8 @@ class BaseTokenToKVPool:
         self.is_in_free_group = False
         self.free_group = []

-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()

-
-        raise NotImplementedError()
-
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
-class MHATokenToKVPool(BaseTokenToKVPool):
+class MHATokenToKVPool(KVCache):

     def __init__(
         self,
@@ -179,8 +189,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -192,7 +208,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):

         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
-            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB
+            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )

     def _create_buffers(self):
@@ -297,7 +313,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


-class MLATokenToKVPool(
+class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -308,8 +324,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank

         memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -356,7 +378,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         self.kv_buffer[layer_id][loc] = cache_k


-class DoubleSparseTokenToKVPool(
+class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -368,8 +390,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         heavy_channel_num: int,
         enable_memory_saver: bool,
     ):
-
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -437,7 +465,7 @@ def synchronized(func):
     return wrapper


-class
+class MHATokenToKVPoolHost:

     def __init__(
         self,
@@ -502,6 +530,9 @@ class MLATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
     @debug_timing
     def transfer(self, indices, flat_data):
         # backup prepared data from device to host
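Several of the hunks above move fp8 handling into the `KVCache` implementations: when the KV dtype is `torch.float8_e5m2` or `torch.float8_e4m3fn`, the buffers are stored as `torch.uint8` (since `Tensor.index_put` is not implemented for float8) and bit-reinterpreted with `.view()` on access, as `copy_two_array` shows. A minimal sketch of that trick (requires a PyTorch build with float8 dtypes; shapes are illustrative):

```python
import torch

# fp8 data is kept in a uint8 buffer and reinterpreted with .view() on access.
dtype = torch.float8_e4m3fn
store_dtype = torch.uint8

kv_buffer = torch.zeros(16, 8, dtype=store_dtype)   # physical buffer, stored as uint8
cache_k = torch.randn(4, 8).to(dtype)               # incoming fp8 keys for 4 tokens
loc = torch.tensor([3, 5, 7, 9])                    # slots handed out by the allocator

kv_buffer[loc] = cache_k.view(store_dtype)          # write: reinterpret fp8 bits as uint8
restored = kv_buffer[loc].view(dtype)               # read: reinterpret back to fp8
assert restored.dtype == dtype
```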
|