sglang-0.5.3rc0-py3-none-any.whl → sglang-0.5.3rc2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hicache_storage.py

@@ -7,6 +7,8 @@ from typing import Any, List, Optional
 
 import torch
 
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+
 logger = logging.getLogger(__name__)
 
 
@@ -32,15 +34,46 @@ class HiCacheStorageConfig:
     extra_config: Optional[dict] = None
 
 
+@dataclass
+class HiCacheStorageExtraInfo:
+    extra_info: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
    It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """
 
-    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool
 
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        self.mem_pool_host = mem_pool_host
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of tensors or None for each key.
+        """
+        pass
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of tensors or None for each key.
+        """
+        pass
+
     @abstractmethod
     def get(
         self,
@@ -54,6 +87,7 @@ class HiCacheStorage(ABC):
         """
         pass
 
+    # TODO: Deprecate
     @abstractmethod
     def batch_get(
         self,
@@ -81,6 +115,7 @@ class HiCacheStorage(ABC):
         """
         pass
 
+    # TODO: Deprecate
     @abstractmethod
     def batch_set(
         self,
@@ -103,6 +138,7 @@ class HiCacheStorage(ABC):
         """
         pass
 
+    # TODO: Use a finer-grained return type (e.g., List[bool])
    def batch_exists(self, keys: List[str]) -> int:
        """
        Check if the keys exist in the storage.
@@ -114,6 +150,9 @@ class HiCacheStorage(ABC):
                 return i
         return len(keys)
 
+    def clear(self) -> None:
+        pass
+
     def get_stats(self):
         return None
 
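Alongside the legacy `batch_get` / `batch_set` (now marked for deprecation), the abstract class gains `batch_get_v1` / `batch_set_v1`, which take the page keys together with the host-pool indices they map to and return one success flag per key; `register_mem_pool_host` hands the backend the `HostKVCache` it should read from and write into. Below is a standalone sketch of a backend exposing that call shape; the `InMemoryStorageSketch` class and its dict store are hypothetical illustrations, only the method signatures mirror the interface above (with a plain dict standing in for `HiCacheStorageExtraInfo`).

from typing import Dict, List, Optional

import torch


class InMemoryStorageSketch:
    """Hypothetical toy backend; only the batch_*_v1 signatures mirror the diff."""

    def __init__(self) -> None:
        self._store: Dict[str, torch.Tensor] = {}

    def batch_set_v1(
        self,
        keys: List[str],
        host_indices: torch.Tensor,
        extra_info: Optional[dict] = None,
    ) -> List[bool]:
        # A real backend would copy the pages addressed by host_indices out of
        # the registered host memory pool; this sketch stores the index slices
        # themselves so it stays self-contained.
        per_key = host_indices.numel() // max(len(keys), 1)
        for i, key in enumerate(keys):
            self._store[key] = host_indices[i * per_key : (i + 1) * per_key].clone()
        return [True] * len(keys)

    def batch_get_v1(
        self,
        keys: List[str],
        host_indices: torch.Tensor,
        extra_info: Optional[dict] = None,
    ) -> List[bool]:
        # Report a per-key hit flag; a real backend would also write the data
        # back into the host pool at host_indices.
        return [key in self._store for key in keys]


if __name__ == "__main__":
    backend = InMemoryStorageSketch()
    print(backend.batch_set_v1(["page-0", "page-1"], torch.arange(128)))
    print(backend.batch_get_v1(["page-0", "page-x"], torch.arange(128)))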
sglang/srt/mem_cache/hiradix_cache.py

@@ -1,8 +1,8 @@
 import heapq
+import json
 import logging
 import threading
 import time
-from queue import Queue
 from typing import List, Optional
 
 import torch
@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
     MHATokenToKVPoolHost,
     MLATokenToKVPoolHost,
 )
-from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
 from sglang.srt.metrics.collector import StorageMetricsCollector
 
 logger = logging.getLogger(__name__)
@@ -39,17 +39,19 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         enable_metrics: bool,
+        eviction_policy: str = "lru",
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
         model_name: Optional[str] = None,
         storage_backend_extra_config: Optional[str] = None,
+        is_eagle: bool = False,
     ):
 
         if hicache_io_backend == "direct":
             if hicache_mem_layout == "page_first":
-                hicache_mem_layout = "
+                hicache_mem_layout = "page_first_direct"
                 logger.warning(
-                    "Page first layout is not supported with direct IO backend, switching to
+                    "Page first layout is not supported with direct IO backend, switching to page first direct layout"
                 )
 
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
@@ -77,9 +79,19 @@ class HiRadixCache(RadixCache):
         self.enable_storage = hicache_storage_backend is not None
         self.enable_storage_metrics = self.enable_storage and enable_metrics
 
-
-
-
+        (
+            extra_config,
+            prefetch_threshold,
+            prefetch_timeout_base,
+            prefetch_timeout_per_ki_token,
+        ) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
+        self.prefetch_threshold = prefetch_threshold
+        self.prefetch_timeout_base = prefetch_timeout_base
+        self.prefetch_timeout_per_page = (
+            page_size / 1024 * prefetch_timeout_per_ki_token
+        )
+        # TODO: support more timeout check functions
+        self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
         self.prefetch_stop_policy = hicache_storage_prefetch_policy
 
         self.load_cache_event = threading.Event()
@@ -94,7 +106,7 @@ class HiRadixCache(RadixCache):
             storage_backend=hicache_storage_backend,
             prefetch_threshold=self.prefetch_threshold,
             model_name=model_name,
-            storage_backend_extra_config=
+            storage_backend_extra_config=extra_config,
         )
         if self.enable_storage_metrics:
             # TODO: support pp
@@ -117,8 +129,61 @@ class HiRadixCache(RadixCache):
             1 if hicache_write_policy == "write_through" else 2
         )
         self.load_back_threshold = 10
+
         super().__init__(
-            req_to_token_pool,
+            req_to_token_pool,
+            token_to_kv_pool_allocator,
+            page_size,
+            disable=False,
+            eviction_policy=eviction_policy,
+            is_eagle=is_eagle,
+        )
+
+    def _parse_storage_backend_extra_config(
+        self, storage_backend_extra_config: Optional[str]
+    ):
+        """
+        Parse storage backend extra config JSON and extract specific parameters.
+
+        Args:
+            storage_backend_extra_config: JSON string containing extra configuration
+
+        Returns:
+            tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
+        """
+        # Parse extra config JSON if provided
+        extra_config = {}
+        if storage_backend_extra_config:
+            try:
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+                raise e
+
+        prefetch_threshold = extra_config.pop("prefetch_threshold", 256)  # tokens
+        prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1)  # seconds
+        prefetch_timeout_per_ki_token = extra_config.pop(
+            "prefetch_timeout_per_ki_token", 0.25
+        )  # seconds per 1024 tokens
+
+        if not isinstance(prefetch_threshold, int):
+            raise ValueError(
+                f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
+            )
+        if not isinstance(prefetch_timeout_base, (int, float)):
+            raise ValueError(
+                f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
+            )
+        if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
+            raise ValueError(
+                f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
+            )
+
+        return (
+            extra_config,
+            prefetch_threshold,
+            float(prefetch_timeout_base),
+            float(prefetch_timeout_per_ki_token),
         )
 
     def reset(self):
@@ -258,12 +323,15 @@ class HiRadixCache(RadixCache):
 
     def evict(self, num_tokens: int):
         leaves = self._collect_leaves_device()
-
+        eviction_heap = [
+            (self.eviction_strategy.get_priority(node), node) for node in leaves
+        ]
+        heapq.heapify(eviction_heap)
 
         num_evicted = 0
         write_back_nodes = []
-        while num_evicted < num_tokens and len(
-            x = heapq.heappop(
+        while num_evicted < num_tokens and len(eviction_heap):
+            _priority, x = heapq.heappop(eviction_heap)
 
             if x.lock_ref > 0:
                 continue
@@ -285,7 +353,8 @@ class HiRadixCache(RadixCache):
                    break
            else:
                # all children are evicted or no children
-
+                new_priority = self.eviction_strategy.get_priority(x.parent)
+                heapq.heappush(eviction_heap, (new_priority, x.parent))
 
         if self.cache_controller.write_policy == "write_back":
             self.writing_check(write_back=True)
@@ -295,7 +364,7 @@ class HiRadixCache(RadixCache):
 
     def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
-        num_evicted = self.cache_controller.evict_device(node.value
+        num_evicted = self.cache_controller.evict_device(node.value)
         assert num_evicted > 0
         self.evictable_size_ -= num_evicted
         node.value = None
@@ -310,11 +379,14 @@ class HiRadixCache(RadixCache):
 
     def evict_host(self, num_tokens: int):
         leaves = self._collect_leaves()
-
+        eviction_heap = [
+            (self.eviction_strategy.get_priority(node), node) for node in leaves
+        ]
+        heapq.heapify(eviction_heap)
 
         num_evicted = 0
-        while num_evicted < num_tokens and len(
-            x = heapq.heappop(
+        while num_evicted < num_tokens and len(eviction_heap):
+            _priority, x = heapq.heappop(eviction_heap)
             if x == self.root_node:
                 break
             # only evict the host value of evicted nodes
@@ -333,7 +405,8 @@ class HiRadixCache(RadixCache):
            del x.parent.children[k]
 
            if len(x.parent.children) == 0 and x.parent.evicted:
-
+                new_priority = self.eviction_strategy.get_priority(x.parent)
+                heapq.heappush(eviction_heap, (new_priority, x.parent))
 
     def load_back(
         self, node: TreeNode, mem_quota: Optional[int] = None
@@ -476,6 +549,15 @@ class HiRadixCache(RadixCache):
            host_indices = torch.cat(host_indices_list, dim=0)
            cc.mem_pool_host.free(host_indices)
 
+    # Timeout is linearly increasing with the number of pages
+    def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
+        # If hash_value has not been computed in timeout_base seconds, terminate it.
+        return (
+            time.monotonic() - operation.start_time
+            > self.prefetch_timeout_base
+            + len(operation.hash_value) * self.prefetch_timeout_per_page
+        )
+
     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
 
@@ -492,9 +574,7 @@ class HiRadixCache(RadixCache):
        if self.prefetch_stop_policy == "wait_complete":
            can_terminate = completed
        elif self.prefetch_stop_policy == "timeout":
-            can_terminate = completed or (
-                time.monotonic() - operation.start_time > self.prefetch_timeout
-            )
+            can_terminate = completed or self.is_prefetch_timeout(operation)
        else:
            # unknown prefetch stop policy, just return True
            return True
@@ -556,12 +636,12 @@ class HiRadixCache(RadixCache):
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
-
+            RadixKey(
+                token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
+            ),
            written_indices,
            hash_value[: min_completed_tokens // self.page_size],
        )
-        if len(written_indices):
-            self.cache_controller.mem_pool_host.update_prefetch(written_indices)
 
        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.append_host_mem_release(
@@ -578,8 +658,9 @@ class HiRadixCache(RadixCache):
 
         return True
 
-    def match_prefix(self, key:
+    def match_prefix(self, key: RadixKey, **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
+        key.token_ids = self.key_convert_fn(key.token_ids)
         if self.disable or len(key) == 0:
             return MatchResult(
                 device_indices=empty_value,
@@ -652,7 +733,9 @@ class HiRadixCache(RadixCache):
         )
         self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
 
-    def _insert_helper_host(
+    def _insert_helper_host(
+        self, node: TreeNode, key: RadixKey, host_value, hash_value
+    ):
         node.last_access_time = time.monotonic()
         if len(key) == 0:
             return 0
@@ -686,7 +769,7 @@ class HiRadixCache(RadixCache):
            node.children[child_key] = new_node
        return matched_length
 
-    def _match_prefix_helper(self, node: TreeNode, key:
+    def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
        node.last_access_time = time.monotonic()
        child_key = self.get_child_key_fn(key)
        value = []
@@ -712,7 +795,7 @@ class HiRadixCache(RadixCache):
 
        return value, node
 
-    def _split_node(self, key, child: TreeNode, split_len: int):
+    def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
@@ -739,10 +822,16 @@ class HiRadixCache(RadixCache):
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node
 
-    def insert(self, key:
+    def insert(self, key: RadixKey, value=None, chunked=False):
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
        if len(key) == 0:
            return 0
 
+        if self.is_eagle and value is not None:
+            # Make sure the value len equal to the EAGLE bigram key len
+            value = value[: len(key)]
+
        node = self.root_node
        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0
@@ -757,7 +846,6 @@ class HiRadixCache(RadixCache):
                # change the reference if the node is evicted
                # this often happens in the case of KV cache recomputation
                node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(node.host_value)
                self.evictable_size_ += len(node.value)
            else:
                self._inc_hit_count(node, chunked)
@@ -767,7 +855,6 @@ class HiRadixCache(RadixCache):
            new_node = self._split_node(node.key, node, prefix_len)
            if new_node.evicted:
                new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
                self.evictable_size_ += len(new_node.value)
            else:
                self._inc_hit_count(new_node, chunked)
@@ -797,7 +884,7 @@ class HiRadixCache(RadixCache):
            for idx in range(0, len(key), self.page_size):
                new_node.hash_value.append(
                    self.cache_controller.get_hash_str(
-                        key[idx : idx + self.page_size],
+                        key.token_ids[idx : idx + self.page_size],
                        prior_hash=last_hash,
                    )
                )
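The constructor above now reads its prefetch tuning from the `storage_backend_extra_config` JSON string instead of fixed values: `prefetch_threshold` (default 256 tokens), `prefetch_timeout_base` (default 1 s), and `prefetch_timeout_per_ki_token` (default 0.25 s per 1024 tokens) are popped from the config, and `_prefetch_timeout_check_linear_func` terminates a prefetch once its age exceeds `prefetch_timeout_base + num_pages * prefetch_timeout_per_page`, where `prefetch_timeout_per_page = page_size / 1024 * prefetch_timeout_per_ki_token`. A small worked example of that budget, assuming a page size of 64 tokens (the page size and JSON string are illustrative, not taken from the diff):

import json

# Defaults mirror _parse_storage_backend_extra_config above.
raw = '{"prefetch_timeout_base": 1, "prefetch_timeout_per_ki_token": 0.25}'
extra_config = json.loads(raw)

page_size = 64  # tokens per page (assumed for this example)
timeout_base = float(extra_config.pop("prefetch_timeout_base", 1))
per_ki_token = float(extra_config.pop("prefetch_timeout_per_ki_token", 0.25))
timeout_per_page = page_size / 1024 * per_ki_token  # 0.015625 s per page here

# Deadline used by the linear timeout check: it grows with the number of
# pages (hash values) already associated with the prefetch operation.
num_pages = 32
deadline_s = timeout_base + num_pages * timeout_per_page
print(f"{deadline_s:.3f}s budget for {num_pages} pages")  # 1.500s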
sglang/srt/mem_cache/memory_pool.py

@@ -15,6 +15,8 @@ limitations under the License.
 
 from __future__ import annotations
 
+from sglang.srt.layers.attention.nsa import index_buf_accessor
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
@@ -1030,6 +1032,8 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        use_nsa: bool = False,
+        override_kv_cache_dim: Optional[int] = None,
     ):
         super().__init__(
             size,
@@ -1044,6 +1048,14 @@ class MLATokenToKVPool(KVCache):
 
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.use_nsa = use_nsa
+        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
+        # TODO do not hardcode
+        self.kv_cache_dim = (
+            656
+            if self.use_nsa and self.nsa_kv_cache_store_fp8
+            else (kv_lora_rank + qk_rope_head_dim)
+        )
 
         # for disagg with nvlink
         self.enable_custom_mem_pool = get_bool_env_var(
@@ -1067,7 +1079,7 @@ class MLATokenToKVPool(KVCache):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.zeros(
-                (size + page_size, 1,
+                (size + page_size, 1, self.kv_cache_dim),
                 dtype=self.store_dtype,
                 device=device,
             )
@@ -1130,6 +1142,7 @@ class MLATokenToKVPool(KVCache):
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
+        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -1147,16 +1160,28 @@ class MLATokenToKVPool(KVCache):
         cache_k_rope: torch.Tensor,
     ):
         layer_id = layer.layer_id
-        if cache_k_nope.dtype != self.dtype:
-            cache_k_nope = cache_k_nope.to(self.dtype)
-            cache_k_rope = cache_k_rope.to(self.dtype)
-        if self.store_dtype != self.dtype:
-            cache_k_nope = cache_k_nope.view(self.store_dtype)
-            cache_k_rope = cache_k_rope.view(self.store_dtype)
 
-
-
-
+        if self.use_nsa and self.nsa_kv_cache_store_fp8:
+            # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
+            # TODO no need to cat
+            cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
+            cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
+            cache_k = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+        else:
+            if cache_k_nope.dtype != self.dtype:
+                cache_k_nope = cache_k_nope.to(self.dtype)
+                cache_k_rope = cache_k_rope.to(self.dtype)
+            if self.store_dtype != self.dtype:
+                cache_k_nope = cache_k_nope.view(self.store_dtype)
+                cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+            set_mla_kv_buffer_triton(
+                self.kv_buffer[layer_id - self.start_layer],
+                loc,
+                cache_k_nope,
+                cache_k_rope,
+            )
 
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
@@ -1186,6 +1211,103 @@ class MLATokenToKVPool(KVCache):
         torch.cuda.synchronize()
 
 
+class NSATokenToKVPool(MLATokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        kv_lora_rank: int,
+        dtype: torch.dtype,
+        qk_rope_head_dim: int,
+        layer_num: int,
+        device: str,
+        index_head_dim: int,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+            use_nsa=True,
+        )
+        # self.index_k_dtype = torch.float8_e4m3fn
+        # self.index_k_scale_dtype = torch.float32
+        self.index_head_dim = index_head_dim
+        # num head == 1 and head dim == 128 for index_k in NSA
+        assert index_head_dim == 128
+
+        self.quant_block_size = 128
+
+        assert self.page_size == 64
+        self.index_k_with_scale_buffer = [
+            torch.zeros(
+                # Layout:
+                # ref: test_attention.py :: kv_cache_cast_to_fp8
+                # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+                # data: for page i,
+                # * buf[i, :page_size * head_dim] for fp8 data
+                # * buf[i, page_size * head_dim:].view(float32) for scale
+                (
+                    (size + page_size + 1) // self.page_size,
+                    self.page_size
+                    * (index_head_dim + index_head_dim // self.quant_block_size * 4),
+                ),
+                dtype=torch.uint8,
+                device=device,
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self.index_k_with_scale_buffer[layer_id - self.start_layer]
+
+    def get_index_k_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetK.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    def get_index_k_scale_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetS.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    # TODO rename later (currently use diff name to avoid confusion)
+    def set_index_k_and_scale_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+        index_k_scale: torch.Tensor,
+    ) -> None:
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        index_buf_accessor.SetKAndS.execute(
+            pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
+        )
+
+
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
         self,
@@ -1194,6 +1316,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
+        index_head_dim: Optional[int],
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
@@ -1213,6 +1336,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
 
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.index_head_dim = index_head_dim
 
         self.custom_mem_pool = None
 
@@ -1240,6 +1364,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            dtype=self.store_dtype,
            device=self.device,
        )
+        if self.index_head_dim is not None:
+            self.index_k_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    1,
+                    self.index_head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
 
         self._finalize_allocation_log(size)
 
@@ -1251,6 +1387,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            kv_size_bytes += get_tensor_size_bytes(k_cache)
        for v_cache in self.v_buffer:
            kv_size_bytes += get_tensor_size_bytes(v_cache)
+        if self.index_head_dim is not None:
+            assert hasattr(self, "index_k_buffer")
+            for index_k_cache in self.index_k_buffer:
+                kv_size_bytes += get_tensor_size_bytes(index_k_cache)
        return kv_size_bytes
 
     def get_kv_buffer(self, layer_id: int):
@@ -1277,6 +1417,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]
 
+    def get_index_k_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.index_k_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -1289,6 +1437,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
        ]
+        if self.index_head_dim is not None:
+            kv_data_ptrs += [
+                self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
+            ]
+            kv_data_lens += [
+                self.index_k_buffer[i].nbytes for i in range(self.layer_num)
+            ]
+            kv_item_lens += [
+                self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
+            ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def set_kv_buffer(
@@ -1325,6 +1483,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            cache_v.view(-1, 1, self.qk_rope_head_dim),
        )
 
+    def set_index_k_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+    ):
+        if index_k.dtype != self.dtype:
+            index_k = index_k.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            index_k = index_k.view(self.store_dtype)
+
+        torch_npu.npu_scatter_nd_update_(
+            self.index_k_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.index_head_dim
+            ),
+            loc.view(-1, 1),
+            index_k.view(-1, 1, self.index_head_dim),
+        )
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
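The `index_k_with_scale_buffer` added for `NSATokenToKVPool` packs, per page of 64 tokens, the FP8 index-key payload followed by one float32 scale per token into a single uint8 row, as the layout comment in the diff describes. Below is a short sketch of that layout arithmetic using the sizes the diff asserts (`page_size == 64`, `index_head_dim == 128`, `quant_block_size == 128`); the buffer here is a throwaway CPU tensor, not the real pool.

import torch

# Byte layout of one page row in index_k_with_scale_buffer. The fp8 dtype
# below requires a recent PyTorch (>= 2.1).
page_size = 64
index_head_dim = 128
quant_block_size = 128

row_bytes = page_size * (index_head_dim + index_head_dim // quant_block_size * 4)
fp8_bytes = page_size * index_head_dim  # leading FP8 index-key payload

page_row = torch.zeros(row_bytes, dtype=torch.uint8)
fp8_view = page_row[:fp8_bytes].view(torch.float8_e4m3fn).view(page_size, index_head_dim)
scale_view = page_row[fp8_bytes:].view(torch.float32)  # one scale per token

print(row_bytes, fp8_view.shape, scale_view.shape)
# 8448 torch.Size([64, 128]) torch.Size([64])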