sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
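The guarded import above is the standard recipe for a type-only dependency: `LayerDoneCounter` lives in `sglang.srt.managers.cache_controller`, and importing it unconditionally would risk a circular import at runtime, so it is made visible to type checkers only. A minimal sketch of the pattern, with illustrative module and class names (not sglang's):

```python
# pool.py -- sketch of the TYPE_CHECKING pattern; 'controller' / 'DoneCounter'
# are made-up names for illustration
from __future__ import annotations  # annotations are stored as strings, not evaluated

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers see this import; at runtime it is skipped,
    # so a pool.py <-> controller.py import cycle cannot occur.
    from controller import DoneCounter


class Pool:
    def register_counter(self, counter: DoneCounter) -> None:
        # at runtime, 'DoneCounter' above is just an unevaluated string annotation
        self.counter = counter
```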
@@ -97,6 +102,207 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class MambaPool:
+    def __init__(
+        self,
+        size: int,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        num_mamba_layers: int,
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        device: str,
+        speculative_num_draft_tokens: Optional[int] = None,
+    ):
+        conv_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + conv_state_shape,
+            dtype=conv_dtype,
+            device=device,
+        )
+        temporal_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + temporal_state_shape,
+            dtype=ssm_dtype,
+            device=device,
+        )
+        if speculative_num_draft_tokens is not None:
+            # Cache intermediate SSM states per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+            intermediate_ssm_state_cache = torch.empty(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    temporal_state_shape[0],
+                    temporal_state_shape[1],
+                    temporal_state_shape[2],
+                ),
+                dtype=ssm_dtype,
+                device="cuda",
+            )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.empty(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
+            self.mamba_cache = (
+                conv_state,
+                temporal_state,
+                intermediate_ssm_state_cache,
+                intermediate_conv_window_cache,
+            )
+        else:
+            self.mamba_cache = (conv_state, temporal_state)
+        self.size = size
+        self.free_slots = list(range(size))
+        self.mem_usage = self.get_mamba_size() / GB
+        logger.info(
+            f"Mamba Cache is allocated. "
+            f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
+            f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB "
+        )
+
+    def get_mamba_params_all_layers(self):
+        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_params(self, layer_id: int):
+        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_size(self):
+        return (
+            np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
+            + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
+        )
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, (int,)):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+
+    def clear(self):
+        self.free_slots = list(range(self.size))
+
+
+class HybridReqToTokenPool(ReqToTokenPool):
+    """A memory pool that maps a request to its token locations."""
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        mamba_layers: List[int],
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        speculative_num_draft_tokens: int,
+    ):
+        super().__init__(
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+        )
+
+        self.mamba_pool = MambaPool(
+            size,
+            conv_dtype,
+            ssm_dtype,
+            len(mamba_layers),
+            conv_state_shape,
+            temporal_state_shape,
+            device,
+            speculative_num_draft_tokens,
+        )
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+
+        self.device = device
+        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
+            size, dtype=torch.int32, device=self.device
+        )
+
+        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+    # For chunk prefill req, we do not need to allocate mamba cache,
+    # We could use allocated mamba cache instead.
+    def alloc(
+        self, need_size: int, reqs: Optional[List["Req"]] = None
+    ) -> Optional[List[int]]:
+        select_index = super().alloc(need_size)
+        if select_index == None:
+            return None
+
+        mamba_index = []
+        for req in reqs:
+            rid = req.rid
+            if rid in self.rid_to_mamba_index_mapping:
+                mid = self.rid_to_mamba_index_mapping[rid]
+            elif (mid := self.mamba_pool.alloc(1)) is not None:
+                mid = mid[0]
+                self.rid_to_mamba_index_mapping[rid] = mid
+                self.mamba_index_to_rid_mapping[mid] = rid
+            mamba_index.append(mid)
+        assert len(select_index) == len(
+            mamba_index
+        ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
+        return select_index
+
+    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+        return self.req_index_to_mamba_index_mapping[req_indices]
+
+    def get_mamba_params(self, layer_id: int):
+        assert layer_id in self.mamba_map
+        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+
+    def get_mamba_params_all_layers(self):
+        return self.mamba_pool.get_mamba_params_all_layers()
+
+    # For chunk prefill, we can not free mamba cache, we need use it in the future
+    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        super().free(free_index)
+        if free_mamba_cache:
+            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+            mamba_index_list = mamba_index.tolist()
+            if isinstance(mamba_index_list, int):
+                mamba_index_list = [mamba_index_list]
+            self.mamba_pool.free(mamba_index_list)
+            for mid in mamba_index_list:
+                rid = self.mamba_index_to_rid_mapping[mid]
+                self.mamba_index_to_rid_mapping.pop(mid)
+                self.rid_to_mamba_index_mapping.pop(rid)
+
+    def clear(self):
+        super().clear()
+        self.mamba_pool.clear()
+
+
 class KVCache(abc.ABC):
     @abc.abstractmethod
     def __init__(
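`MambaPool` is a plain free-list allocator over preallocated per-layer conv and SSM state tensors, and `HybridReqToTokenPool` pairs each request slot with a Mamba slot, reusing the slot across chunked-prefill passes of the same request (keyed by `rid`). A hedged sketch of the allocation contract; the shapes and dtypes below are made up for illustration and do not come from any real model config:

```python
import torch

from sglang.srt.mem_cache.memory_pool import MambaPool

# Illustrative sizes only; real values come from the model config.
pool = MambaPool(
    size=8,                            # number of request slots
    conv_dtype=torch.bfloat16,
    ssm_dtype=torch.float32,
    num_mamba_layers=2,
    conv_state_shape=(64, 3),
    temporal_state_shape=(4, 16, 16),
    device="cpu",
)

slots = pool.alloc(2)                  # e.g. [0, 1]; returns None if not enough slots
conv, ssm = pool.get_mamba_params(0)   # per-layer views of the two state tensors
pool.free(slots)                       # returns the slots and zeroes their state
assert pool.available_size() == 8
```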
@@ -130,6 +336,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -152,7 +381,7 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter
 
     def get_cpu_copy(self, indices):
@@ -205,15 +434,9 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
-
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -352,7 +575,6 @@ class MHATokenToKVPool(KVCache):
         # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
         return self._get_key_buffer(layer_id)
 
     def _get_value_buffer(self, layer_id: int):
@@ -420,50 +642,119 @@ class MHATokenToKVPool(KVCache):
         )
 
 
-class SWAKVPool(KVCache):
-    """KV cache with separate pools for full and SWA attention layers."""
+class HybridLinearKVPool(KVCache):
+    """KV cache with separate pools for full and linear attention layers."""
 
     def __init__(
         self,
         size: int,
-        size_swa: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
-        swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
     ):
         self.size = size
-        self.size_swa = size_swa
         self.dtype = dtype
         self.device = device
-        self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
         self.page_size = 1
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-
-        self.swa_kv_pool = MHATokenToKVPool(
-            size=size_swa,
+        self.full_kv_pool = MHATokenToKVPool(
+            size=size,
             page_size=self.page_size,
             dtype=dtype,
             head_num=head_num,
             head_dim=head_dim,
-            layer_num=self.swa_layer_nums,
+            layer_num=self.full_layer_nums,
             device=device,
             enable_memory_saver=False,
         )
-        self.full_kv_pool = MHATokenToKVPool(
+        self.full_attention_layer_id_mapping = {
+            id: i for i, id in enumerate(full_attention_layer_ids)
+        }
+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
+    def get_kv_size_bytes(self):
+        return self.full_kv_pool.get_kv_size_bytes()
+
+    def get_contiguous_buf_infos(self):
+        return self.full_kv_pool.get_contiguous_buf_infos()
+
+    def _transfer_full_attention_id(self, layer_id: int):
+        if layer_id not in self.full_attention_layer_id_mapping:
+            raise ValueError(
+                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+            )
+        return self.full_attention_layer_id_mapping[layer_id]
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_key_buffer(layer_id)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_value_buffer(layer_id)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_kv_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+        layer_id = self._transfer_full_attention_id(layer.layer_id)
+        self.full_kv_pool.set_kv_buffer(
+            None,
+            loc,
+            cache_k,
+            cache_v,
+            k_scale,
+            v_scale,
+            layer_id_override=layer_id,
+        )
+
+
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+
+        self.swa_kv_pool = token_to_kv_pool_class(
+            size=size_swa,
+            layer_num=self.swa_layer_nums,
+            **kwargs,
+        )
+        self.full_kv_pool = token_to_kv_pool_class(
             size=size,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.full_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
         )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
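`HybridLinearKVPool` allocates full-attention KV cache only for the layers that need it, remapping global layer ids onto a dense pool index, while `SWAKVPool` is generalized to build its two sub-pools from any `token_to_kv_pool_class` instead of hard-coding `MHATokenToKVPool`. The remapping is just a dictionary lookup; a self-contained sketch with a hypothetical layer layout:

```python
# Hypothetical model: six layers, of which only layers 2 and 5 use full attention.
full_attention_layer_ids = [2, 5]
mapping = {layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)}

def transfer_full_attention_id(layer_id: int) -> int:
    # Mirrors HybridLinearKVPool._transfer_full_attention_id.
    if layer_id not in mapping:
        raise ValueError(f"{layer_id=} not in full attention layers: {mapping.keys()}")
    return mapping[layer_id]

assert transfer_full_attention_id(5) == 1  # global layer 5 -> slot 1 in the pool
```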
@@ -768,13 +1059,7 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -918,6 +1203,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.kv_lora_rank,
             ),
             dtype=self.store_dtype,
@@ -928,19 +1214,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.qk_rope_head_dim,
             ),
             dtype=self.store_dtype,
             device=self.device,
         )
 
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
@@ -1000,9 +1281,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
 
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
 
         if cache_v is None:
             cache_k, cache_v = cache_k.split(
sglang/srt/mem_cache/memory_pool_host.py

@@ -3,16 +3,17 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional
 
 import psutil
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu
 
 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)
 
     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
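The new `Optional` annotation documents behavior callers already had to handle: `alloc` returns `None` rather than raising when the host pool lacks space. A caller sketch under that contract; `pool` and `on_full` are illustrative names, not sglang APIs:

```python
def alloc_or_backoff(pool, need_size: int, on_full):
    # HostKVCache.alloc returns None when the host pool cannot satisfy the
    # request, so the caller must branch instead of assuming a tensor.
    indices = pool.alloc(need_size)
    if indices is None:
        return on_full()  # e.g. evict, shrink the request, or retry later
    return indices
```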
@@ -464,11 +465,11 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
-        local_rank = get_tensor_model_parallel_rank()
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
@@ -501,20 +502,23 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))
 
         key_list = []
         buf_list = []
 
-        for key, i in zip(keys, indices[:: self.page_size]):
+        for i in range(len(keys)):
+            key = keys[i]
             key_list.append(f"{key}-k")
-            buf_list.append(self.k_buffer[i : i + self.page_size])
             key_list.append(f"{key}-v")
-            buf_list.append(self.v_buffer[i : i + self.page_size])
+            if indices is not None:
+                index = indices[i * self.page_size]
+                buf_list.append(self.k_buffer[index : index + self.page_size])
+                buf_list.append(self.v_buffer[index : index + self.page_size])
 
-        return key_list, buf_list
+        return key_list, buf_list, 2
@@ -704,10 +708,11 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
|
|
728
733
|
element_size_list = [element_size] * len(key_list)
|
729
734
|
return key_list, ptr_list, element_size_list
|
730
735
|
|
731
|
-
def get_buffer_with_hash(self, keys, indices):
|
736
|
+
def get_buffer_with_hash(self, keys, indices=None):
|
732
737
|
assert self.layout == "page_first"
|
733
|
-
assert len(keys) == (len(indices) // self.page_size)
|
738
|
+
assert indices is None or (len(keys) == (len(indices) // self.page_size))
|
734
739
|
|
735
740
|
buf_list = []
|
736
741
|
|
737
|
-
|
738
|
-
|
742
|
+
if indices is not None:
|
743
|
+
for i in range(len(keys)):
|
744
|
+
index = indices[i * self.page_size]
|
745
|
+
buf_list.append(self.kv_buffer[index : index + self.page_size])
|
739
746
|
|
740
|
-
return keys, buf_list
|
747
|
+
return keys, buf_list, 1
|
sglang/srt/mem_cache/radix_cache.py

@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()
 
         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0

@@ -62,7 +60,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -195,7 +192,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )
 
-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
         if self.disable:
             return 0
 
@@ -240,7 +237,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -261,7 +258,9 @@ class RadixCache(BasePrefixCache):
         page_aligned_token_ids = token_ids[:page_aligned_len]
 
         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
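`cache_unfinished_req` now threads a `chunked` flag through to `insert`, so the radix tree can tell a chunked-prefill partial insert apart from an ordinary one. A sketch of the call path on the scheduler side; the surrounding names are assumptions, not sglang scheduler code:

```python
def cache_prefill_chunk(tree_cache, req):
    # A chunked-prefill request is cached between chunks with chunked=True;
    # internally this becomes tree_cache.insert(..., chunked=True) on the
    # page-aligned prefix, as shown in the hunk above.
    tree_cache.cache_unfinished_req(req, chunked=True)
```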
sglang/srt/mem_cache/radix_cache_cpp.py

@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids