sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,115 @@
|
|
1
|
+
import argparse
|
2
|
+
import os
|
3
|
+
|
4
|
+
import eic
|
5
|
+
import torch
|
6
|
+
import yaml
|
7
|
+
|
8
|
+
|
9
|
+
def pase_args():
    """Parse known CLI arguments for the EIC storage unit test.

    Unrecognized arguments are silently ignored (``parse_known_args``) so
    the script tolerates flags injected by wrapping harnesses.

    Returns:
        argparse.Namespace with a single ``config`` attribute (path to the
        EIC YAML config file).
    """
    arg_parser = argparse.ArgumentParser(description="EIC Storage Unit Test")
    arg_parser.add_argument(
        "--config",
        "-c",
        type=str,
        default="/sgl-workspace/config/remote-eic.yaml",
        help="EIC yaml config",
    )
    known_args, _unknown = arg_parser.parse_known_args()
    return known_args
|
20
|
+
|
21
|
+
|
22
|
+
def init_eic_client():
    """Build and initialize an ``eic.Client`` from the YAML config file.

    Reads the config named by ``--config``, validates required fields, and
    initializes the client.

    Raises:
        FileNotFoundError: if the config file does not exist.
        AssertionError: if ``remote_url`` or ``eic_log_dir`` is missing.
        RuntimeError: if the EIC client fails to initialize.
    """
    args = pase_args()
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path, "r") as fin:
        config = yaml.safe_load(fin)

    remote_url = config.get("remote_url", None)
    # Bug fix: the original built AssertionError(...) as a bare expression
    # without raising it, so a missing remote_url fell through and produced
    # a confusing TypeError on the slice below.
    if remote_url is None:
        raise AssertionError("remote_url is None")
    endpoint = remote_url[len("eic://") :]
    eic_instance_id = config.get("eic_instance_id", None)
    eic_log_dir = config.get("eic_log_dir", None)
    eic_log_level = config.get("eic_log_level", 2)
    eic_trans_type = config.get("eic_trans_type", 3)
    eic_flag_file = config.get("eic_flag_file", None)

    # Guard: os.path.exists(None) raises TypeError; fail with a clear message.
    if eic_log_dir is None:
        raise AssertionError("eic_log_dir is None")
    if not os.path.exists(eic_log_dir):
        os.makedirs(eic_log_dir, exist_ok=True)
    eic_client = eic.Client()
    init_option = eic.InitOption()
    init_option.log_dir = eic_log_dir
    init_option.log_level = eic.LogLevel(eic_log_level)
    init_option.transport_type = eic.TransportType(eic_trans_type)
    init_option.flag_file = eic_flag_file
    ret = eic_client.init(eic_instance_id, endpoint, init_option)
    if ret != 0:
        raise RuntimeError(f"EIC Client init failed with error code: {ret}")
    return eic_client
|
52
|
+
|
53
|
+
|
54
|
+
def test_set(eic_client):
    """Write 16 bf16 CPU tensors into EIC via ``mset`` and assert success."""
    keys = [f"test_key_{i}" for i in range(16)]
    payloads = [
        torch.ones([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu")
        for _ in range(16)
    ]
    data_keys = eic.StringVector()
    data_vals = eic.IOBuffers()
    # Register each tensor's raw buffer (pointer + byte length) with its key.
    for key, tensor in zip(keys, payloads):
        data_keys.append(key)
        data_vals.append(
            tensor.data_ptr(), tensor.numel() * tensor.element_size(), False
        )
    set_opt = eic.SetOption()
    set_opt.ttl_second = 3
    status_code, _set_outcome = eic_client.mset(data_keys, data_vals, set_opt)
    assert (
        status_code == eic.StatusCode.SUCCESS
    ), f"Set failed with status code: {status_code}"
|
73
|
+
|
74
|
+
|
75
|
+
def test_get(eic_client):
    """Read the 16 keys written by ``test_set`` back into zeroed buffers."""
    keys = [f"test_key_{i}" for i in range(16)]
    sinks = [
        torch.zeros([12, 6, 1, 512], dtype=torch.bfloat16, device="cpu")
        for _ in range(16)
    ]
    data_keys = eic.StringVector()
    data_vals = eic.IOBuffers()
    # Hand EIC the destination buffers (pointer + byte length) per key.
    for key, tensor in zip(keys, sinks):
        data_keys.append(key)
        data_vals.append(
            tensor.data_ptr(), tensor.numel() * tensor.element_size(), False
        )
    get_opt = eic.GetOption()
    status_code, data_vals, _get_outcome = eic_client.mget(
        data_keys, get_opt, data_vals
    )
    assert (
        status_code == eic.StatusCode.SUCCESS
    ), f"Get failed with status code: {status_code}"
|
93
|
+
|
94
|
+
|
95
|
+
def test_exists(eic_client):
    """Check that all 16 test keys are reported as present by ``mexist``."""
    data_keys = eic.StringVector()
    for i in range(16):
        data_keys.append(f"test_key_{i}")
    exists_opt = eic.ExistOption()
    status_code, _exists_outcome = eic_client.mexist(data_keys, exists_opt)
    assert (
        status_code == eic.StatusCode.SUCCESS
    ), f"Exists failed with status code: {status_code}"
|
105
|
+
|
106
|
+
|
107
|
+
def main():
    """Run the EIC storage smoke tests in order: set, exists, get."""
    client = init_eic_client()
    for check in (test_set, test_exists, test_get):
        check(client)


if __name__ == "__main__":
    main()
|
@@ -12,7 +12,12 @@ from typing import Any, List, Optional, Tuple
|
|
12
12
|
|
13
13
|
import torch
|
14
14
|
|
15
|
-
from sglang.srt.mem_cache.hicache_storage import
|
15
|
+
from sglang.srt.mem_cache.hicache_storage import (
|
16
|
+
HiCacheStorage,
|
17
|
+
HiCacheStorageConfig,
|
18
|
+
HiCacheStorageExtraInfo,
|
19
|
+
)
|
20
|
+
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
16
21
|
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
|
17
22
|
from sglang.srt.metrics.collector import StorageMetrics
|
18
23
|
|
@@ -178,11 +183,14 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
178
183
|
self.skip_backup = True
|
179
184
|
self.rank = 0
|
180
185
|
|
186
|
+
self.is_zero_copy = False
|
187
|
+
|
181
188
|
logger.info(
|
182
189
|
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
|
183
190
|
f"file_path={self.file_path}, "
|
184
191
|
f"file_size={self.file_size / (2 ** 30):.2f} GB, "
|
185
|
-
f"num_pages={self.num_pages}"
|
192
|
+
f"num_pages={self.num_pages}, "
|
193
|
+
f"is_mla_model={self.is_mla_model}"
|
186
194
|
)
|
187
195
|
|
188
196
|
self.ac = AtomicCounter(self.numjobs)
|
@@ -323,25 +331,12 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
323
331
|
use_mock_client=use_mock_client,
|
324
332
|
)
|
325
333
|
|
326
|
-
def get(
|
327
|
-
self,
|
328
|
-
key: str,
|
329
|
-
target_location: Optional[Any] = None,
|
330
|
-
target_sizes: Optional[Any] = None,
|
331
|
-
) -> torch.Tensor | None:
|
332
|
-
return self.batch_get(
|
333
|
-
[key],
|
334
|
-
[target_location] if target_location is not None else None,
|
335
|
-
[target_sizes] if target_sizes is not None else None,
|
336
|
-
)[0]
|
337
|
-
|
338
334
|
@synchronized()
|
339
|
-
def
|
335
|
+
def _batch_get(
|
340
336
|
self,
|
341
337
|
keys: List[str],
|
342
|
-
|
343
|
-
|
344
|
-
) -> List[torch.Tensor | None]:
|
338
|
+
values: List[torch.Tensor],
|
339
|
+
) -> List[bool]:
|
345
340
|
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
|
346
341
|
|
347
342
|
batch_indices, file_offsets = [], []
|
@@ -350,15 +345,9 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
350
345
|
batch_indices.append(i)
|
351
346
|
file_offsets.append(page_index * self.bytes_per_page)
|
352
347
|
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
file_results = target_locations
|
357
|
-
else:
|
358
|
-
file_results = [
|
359
|
-
torch.empty(self.numel, dtype=self.dtype)
|
360
|
-
for _ in range(len(batch_indices))
|
361
|
-
]
|
348
|
+
for target_location in values:
|
349
|
+
assert target_location.is_contiguous()
|
350
|
+
file_results = values
|
362
351
|
|
363
352
|
start_time = time.perf_counter()
|
364
353
|
|
@@ -379,12 +368,10 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
379
368
|
ionum / (end_time - start_time) * self.gb_per_page
|
380
369
|
)
|
381
370
|
|
382
|
-
results = [
|
383
|
-
for batch_index,
|
384
|
-
batch_indices, file_results, read_results
|
385
|
-
):
|
371
|
+
results = [False] * len(keys)
|
372
|
+
for batch_index, read_result in zip(batch_indices, read_results):
|
386
373
|
if read_result == self.bytes_per_page:
|
387
|
-
results[batch_index] =
|
374
|
+
results[batch_index] = True
|
388
375
|
else:
|
389
376
|
logger.error(
|
390
377
|
f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
|
@@ -392,28 +379,12 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
392
379
|
|
393
380
|
return results
|
394
381
|
|
395
|
-
def set(
|
396
|
-
self,
|
397
|
-
key: str,
|
398
|
-
value: Optional[Any] = None,
|
399
|
-
target_location: Optional[Any] = None,
|
400
|
-
target_sizes: Optional[Any] = None,
|
401
|
-
) -> bool:
|
402
|
-
return self.batch_set(
|
403
|
-
[key],
|
404
|
-
[value] if value is not None else None,
|
405
|
-
[target_location] if target_location is not None else None,
|
406
|
-
[target_sizes] if target_sizes is not None else None,
|
407
|
-
)
|
408
|
-
|
409
382
|
@synchronized()
|
410
|
-
def
|
383
|
+
def _batch_set(
|
411
384
|
self,
|
412
385
|
keys: List[str],
|
413
386
|
values: Optional[Any] = None,
|
414
|
-
|
415
|
-
target_sizes: Optional[Any] = None,
|
416
|
-
) -> bool:
|
387
|
+
) -> List[bool]:
|
417
388
|
# In MLA backend, only one rank needs to backup the KV cache
|
418
389
|
if self.skip_backup:
|
419
390
|
return True
|
@@ -474,7 +445,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
474
445
|
self.rank, written_keys_to_confirm, pages_to_release
|
475
446
|
)
|
476
447
|
|
477
|
-
return
|
448
|
+
return results
|
478
449
|
|
479
450
|
def delete(self, key: str) -> None:
|
480
451
|
self.metadata_client.delete_keys(self.rank, [key])
|
@@ -484,21 +455,25 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
484
455
|
return result[0] if result else False
|
485
456
|
|
486
457
|
def batch_exists(self, keys: List[str]) -> int:
    """Return how many leading keys are present in the metadata backend.

    Scans ``keys`` in order and stops at the first miss, so the result is
    the length of the existing prefix. For zero-copy MHA layouts each
    logical key expands into a "-k"/"-v" pair, so the raw hit count is
    divided back by two before returning.
    """
    factor = 1
    if self.is_zero_copy and not self.is_mla_model:
        keys = self._get_mha_zero_copy_keys(keys)
        factor = 2

    results = self.metadata_client.exists(self.rank, keys)

    prefix_hits = 0
    for idx in range(len(keys)):
        if not results[idx]:
            break
        prefix_hits += 1

    return prefix_hits // factor
|
470
|
+
|
471
|
+
def clear(self) -> None:
    """Best-effort wipe of this rank's metadata; errors are logged, not raised."""
    try:
        self.metadata_client.clear(self.rank)
        logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
    except Exception as exc:
        logger.error(f"Failed to clear HiCacheHF3FS: {exc}")
|
502
477
|
|
503
478
|
def close(self) -> None:
|
504
479
|
try:
|
@@ -521,3 +496,143 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
521
496
|
self.prefetch_bandwidth.clear()
|
522
497
|
self.backup_bandwidth.clear()
|
523
498
|
return storage_metrics
|
499
|
+
|
500
|
+
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
    """Attach the host KV cache pool, then derive the zero-copy flag.

    Zero-copy is only enabled for the "page_first" host layout, where a
    page's data is contiguous and can be handed to the backend in place.
    """
    super().register_mem_pool_host(mem_pool_host)
    zero_copy_capable = self.mem_pool_host.layout == "page_first"
    self.is_zero_copy = zero_copy_capable
    logger.info(f"{self.is_zero_copy=}")
|
504
|
+
|
505
|
+
def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
    """Expand each logical key into its K/V pair of backend keys.

    MHA stores the key cache and value cache of a page separately, so one
    logical key becomes two backend keys: "<key>-k" then "<key>-v",
    preserving the input order.
    """
    return [f"{key}-{suffix}" for key in keys for suffix in ("k", "v")]
|
511
|
+
|
512
|
+
def _get_mha_zero_copy_values(
    self, values: List[torch.Tensor]
) -> List[torch.Tensor]:
    """Split each page buffer into its two halves, interleaved in order.

    Mirrors ``_get_mha_zero_copy_keys``: page i contributes ``values[i][0]``
    (K half) followed by ``values[i][1]`` (V half).
    """
    return [page[part] for page in values for part in (0, 1)]
|
520
|
+
|
521
|
+
def _batch_get_preprocess(self, keys, host_indices):
    """Map host indices to per-page destination buffers for ``_batch_get``.

    Zero-copy mode hands out views into the host pool itself; otherwise a
    throwaway flat staging page is allocated per page and copied back later
    by ``_batch_get_postprocess``. For zero-copy MHA, keys and buffers are
    both expanded into K/V pairs.
    """
    page_size = self.mem_pool_host.page_size
    page_num = len(host_indices) // page_size
    # host_indices to kv_buffer
    if self.is_zero_copy:
        values = [
            self.mem_pool_host.get_data_page(
                host_indices[i * page_size], flat=False
            )
            for i in range(page_num)
        ]
    else:
        values = [
            self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
        ]

    if self.is_zero_copy and not self.is_mla_model:
        keys = self._get_mha_zero_copy_keys(keys)
        values = self._get_mha_zero_copy_values(values)

    return keys, values
|
543
|
+
|
544
|
+
def _batch_get_postprocess(self, host_indices, values, results):
    """Fold raw per-key results into per-page flags; copy staged data back.

    Zero-copy: data already landed in the host pool, so only the result
    list needs reshaping — for MHA a page succeeded only if both its K and
    V halves did. Non-zero-copy: retrieved flat pages are written back into
    the host pool, stopping at the first miss.
    """
    page_size = self.mem_pool_host.page_size
    page_num = len(host_indices) // page_size

    if self.is_zero_copy:
        if not self.is_mla_model:
            results = [
                results[2 * i] and results[2 * i + 1] for i in range(page_num)
            ]
        return results[:page_num]

    for i in range(page_num):
        if not results[i]:
            break
        self.mem_pool_host.set_from_flat_data_page(
            host_indices[i * page_size], values[i]
        )

    return results
|
563
|
+
|
564
|
+
def batch_get_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """Fetch the pages for ``keys`` into host memory.

    Returns one success flag per page, in input order.
    """
    prepared_keys, buffers = self._batch_get_preprocess(keys, host_indices)
    raw_results = self._batch_get(prepared_keys, buffers)
    return self._batch_get_postprocess(host_indices, buffers, raw_results)
|
573
|
+
|
574
|
+
def _batch_set_preprocess(self, keys, host_indices):
    """Collect per-page source buffers for ``_batch_set``.

    Zero-copy mode passes the host pool's pages directly; otherwise each
    page is flattened into a contiguous copy. For zero-copy MHA, keys and
    buffers are both expanded into K/V pairs.
    """
    page_size = self.mem_pool_host.page_size
    page_num = len(host_indices) // page_size
    # host_indices to kv_buffer
    use_flat = not self.is_zero_copy
    values = [
        self.mem_pool_host.get_data_page(
            host_indices[i * page_size], flat=use_flat
        )
        for i in range(page_num)
    ]

    if self.is_zero_copy and not self.is_mla_model:
        keys = self._get_mha_zero_copy_keys(keys)
        values = self._get_mha_zero_copy_values(values)

    return keys, values
|
590
|
+
|
591
|
+
def batch_set_v1(
    self,
    keys: List[str],
    host_indices: torch.Tensor,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
    """Persist the pages backing ``keys`` to the storage backend.

    Returns one success flag per stored key. NOTE(review): when zero-copy
    MHA expansion applies, ``_batch_set_preprocess`` doubles the key list,
    so the result list is keyed by the expanded "-k"/"-v" keys rather than
    the caller's original keys — confirm callers expect this.
    """
    # Removed the unused `len_keys` local from the original implementation.
    keys, values = self._batch_set_preprocess(keys, host_indices)
    return self._batch_set(keys, values)
|
601
|
+
|
602
|
+
# Deprecated
|
603
|
+
def get(
|
604
|
+
self,
|
605
|
+
key: str,
|
606
|
+
target_location: Optional[Any] = None,
|
607
|
+
target_sizes: Optional[Any] = None,
|
608
|
+
) -> torch.Tensor | None:
|
609
|
+
pass
|
610
|
+
|
611
|
+
# Deprecated
|
612
|
+
def batch_get(
|
613
|
+
self,
|
614
|
+
keys: List[str],
|
615
|
+
target_locations: Optional[Any] = None,
|
616
|
+
target_sizes: Optional[Any] = None,
|
617
|
+
) -> List[torch.Tensor | None] | int:
|
618
|
+
pass
|
619
|
+
|
620
|
+
# Deprecated
|
621
|
+
def set(
|
622
|
+
self,
|
623
|
+
key: str,
|
624
|
+
value: Optional[Any] = None,
|
625
|
+
target_location: Optional[Any] = None,
|
626
|
+
target_sizes: Optional[Any] = None,
|
627
|
+
) -> bool:
|
628
|
+
pass
|
629
|
+
|
630
|
+
# Deprecated
|
631
|
+
def batch_set(
|
632
|
+
self,
|
633
|
+
keys: List[str],
|
634
|
+
values: Optional[Any] = None,
|
635
|
+
target_locations: Optional[Any] = None,
|
636
|
+
target_sizes: Optional[Any] = None,
|
637
|
+
) -> bool:
|
638
|
+
pass
|
@@ -9,7 +9,7 @@ import torch
|
|
9
9
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
10
10
|
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
11
11
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
12
|
-
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
12
|
+
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
13
13
|
|
14
14
|
try:
|
15
15
|
from lmcache.integration.sglang.sglang_adapter import (
|
@@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache):
|
|
78
78
|
tp_size: int = 1,
|
79
79
|
rank: int = 0,
|
80
80
|
tp_group: Optional[torch.distributed.ProcessGroup] = None,
|
81
|
+
eviction_policy: str = "lru",
|
81
82
|
):
|
82
83
|
super().__init__(
|
83
84
|
req_to_token_pool=req_to_token_pool,
|
@@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache):
|
|
85
86
|
page_size=page_size,
|
86
87
|
disable=disable,
|
87
88
|
enable_kv_cache_events=enable_kv_cache_events,
|
89
|
+
eviction_policy=eviction_policy,
|
88
90
|
)
|
89
91
|
|
90
92
|
kvcache = self.token_to_kv_pool_allocator.get_kvcache()
|
@@ -129,7 +131,7 @@ class LMCRadixCache(RadixCache):
|
|
129
131
|
with self._node_lock:
|
130
132
|
self._in_flight_nodes.clear()
|
131
133
|
|
132
|
-
def match_prefix(self, key:
|
134
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override]
|
133
135
|
"""Match cached prefix; if there's a tail miss, prefetch from LMCache.
|
134
136
|
|
135
137
|
Reuses the base matching logic to obtain (value, last_node). If there
|
@@ -176,7 +178,7 @@ class LMCRadixCache(RadixCache):
|
|
176
178
|
with torch.cuda.stream(self.load_stream):
|
177
179
|
num_retrieved = self.lmcache_connector.start_load_kv(
|
178
180
|
LoadMetadata(
|
179
|
-
token_ids=key, # full page-aligned key
|
181
|
+
token_ids=key.token_ids, # full page-aligned key
|
180
182
|
slot_mapping=slot_mapping,
|
181
183
|
offset=value.numel() - prefix_pad, # LMCache offset convention
|
182
184
|
)
|
@@ -225,7 +227,7 @@ class LMCRadixCache(RadixCache):
|
|
225
227
|
req.req_pool_idx, : len(token_ids)
|
226
228
|
]
|
227
229
|
|
228
|
-
_, new_last_node, _, _ = self.match_prefix(token_ids)
|
230
|
+
_, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
|
229
231
|
assert new_last_node is not None
|
230
232
|
|
231
233
|
self.inc_lock_ref(new_last_node)
|
@@ -275,6 +277,8 @@ if __name__ == "__main__":
|
|
275
277
|
rank=0,
|
276
278
|
tp_group=None,
|
277
279
|
)
|
278
|
-
cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
|
279
|
-
cache.insert(
|
280
|
+
cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
|
281
|
+
cache.insert(
|
282
|
+
RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
|
283
|
+
)
|
280
284
|
cache.pretty_print()
|
@@ -7,11 +7,16 @@ from typing import Any, List, Optional
|
|
7
7
|
|
8
8
|
import torch
|
9
9
|
|
10
|
-
from sglang.srt.mem_cache.hicache_storage import
|
10
|
+
from sglang.srt.mem_cache.hicache_storage import (
|
11
|
+
HiCacheStorage,
|
12
|
+
HiCacheStorageConfig,
|
13
|
+
HiCacheStorageExtraInfo,
|
14
|
+
)
|
15
|
+
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
11
16
|
|
12
17
|
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
|
13
18
|
DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
|
14
|
-
|
19
|
+
DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH"
|
15
20
|
logger = logging.getLogger(__name__)
|
16
21
|
|
17
22
|
|
@@ -28,13 +33,13 @@ class MooncakeStoreConfig:
|
|
28
33
|
@staticmethod
|
29
34
|
def from_file() -> "MooncakeStoreConfig":
|
30
35
|
"""Load the config from a JSON file."""
|
31
|
-
file_path = os.getenv(
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
36
|
+
file_path = os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV)
|
37
|
+
try:
|
38
|
+
with open(file_path) as fin:
|
39
|
+
config = json.load(fin)
|
40
|
+
except Exception as e:
|
41
|
+
raise RuntimeError(f"Failed to load config from {file_path}: {str(e)}")
|
42
|
+
|
38
43
|
return MooncakeStoreConfig(
|
39
44
|
local_hostname=config.get("local_hostname"),
|
40
45
|
metadata_server=config.get("metadata_server"),
|
@@ -101,6 +106,7 @@ class MooncakeStoreConfig:
|
|
101
106
|
|
102
107
|
|
103
108
|
class MooncakeStore(HiCacheStorage):
|
109
|
+
|
104
110
|
def __init__(self, storage_config: HiCacheStorageConfig = None):
|
105
111
|
try:
|
106
112
|
from mooncake.store import MooncakeDistributedStore
|
@@ -129,6 +135,10 @@ class MooncakeStore(HiCacheStorage):
|
|
129
135
|
logger.info(
|
130
136
|
"Mooncake Configuration loaded from extra_config successfully."
|
131
137
|
)
|
138
|
+
elif os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV):
|
139
|
+
# Load from config file
|
140
|
+
self.config = MooncakeStoreConfig.from_file()
|
141
|
+
logger.info("Mooncake Configuration loaded from file successfully.")
|
132
142
|
else:
|
133
143
|
# Load from environment variables
|
134
144
|
self.config = MooncakeStoreConfig.load_from_env()
|
@@ -178,7 +188,13 @@ class MooncakeStore(HiCacheStorage):
|
|
178
188
|
assert self.store.is_exist(warmup_key) == 1
|
179
189
|
assert self.store.get(warmup_key) == warmup_value
|
180
190
|
|
181
|
-
def
|
191
|
+
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
|
192
|
+
super().register_mem_pool_host(mem_pool_host)
|
193
|
+
assert self.mem_pool_host.layout in [
|
194
|
+
"page_first",
|
195
|
+
"page_first_direct",
|
196
|
+
], "mooncake store storage backend only support page first or page first direct layout"
|
197
|
+
buffer = self.mem_pool_host.kv_buffer
|
182
198
|
try:
|
183
199
|
buffer_ptr = buffer.data_ptr()
|
184
200
|
buffer_size = buffer.numel() * buffer.element_size()
|
@@ -189,6 +205,97 @@ class MooncakeStore(HiCacheStorage):
|
|
189
205
|
logger.error("Failed to register buffer to Mooncake Store: %s", err)
|
190
206
|
raise TypeError("Mooncake Store Register Buffer Error.") from err
|
191
207
|
|
208
|
+
def _get_mha_buffer_meta(self, keys, indices):
|
209
|
+
ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
|
210
|
+
key_list = []
|
211
|
+
for key_ in keys:
|
212
|
+
key_list.append(f"{key_}_{self.local_rank}_k")
|
213
|
+
key_list.append(f"{key_}_{self.local_rank}_v")
|
214
|
+
assert len(key_list) == len(ptr_list)
|
215
|
+
return key_list, ptr_list, element_size_list
|
216
|
+
|
217
|
+
def _get_mla_buffer_meta(self, keys, indices):
|
218
|
+
ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
|
219
|
+
key_list = []
|
220
|
+
for key_ in keys:
|
221
|
+
key_list.append(f"{key_}_k")
|
222
|
+
assert len(key_list) == len(ptr_list)
|
223
|
+
return key_list, ptr_list, element_size_list
|
224
|
+
|
225
|
+
def _batch_preprocess(self, keys, host_indices):
|
226
|
+
assert len(keys) > 0
|
227
|
+
assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
|
228
|
+
if self.is_mla_backend:
|
229
|
+
return self._get_mla_buffer_meta(keys, host_indices)
|
230
|
+
else:
|
231
|
+
return self._get_mha_buffer_meta(keys, host_indices)
|
232
|
+
|
233
|
+
def _batch_postprocess(self, results: List[int], is_set_operate=False):
|
234
|
+
"""
|
235
|
+
refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
|
236
|
+
for batch_get_into, results is Vector of integers,
|
237
|
+
where each element is the number of bytes read on success, or a negative value on error
|
238
|
+
for batch_put_from, results is Vector of integers,
|
239
|
+
where each element is 0 on success, or a negative value on error
|
240
|
+
"""
|
241
|
+
if self.is_mla_backend:
|
242
|
+
return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
|
243
|
+
else:
|
244
|
+
kv_pairs = zip(results[::2], results[1::2])
|
245
|
+
return [
|
246
|
+
(
|
247
|
+
(k_res == 0 and v_res == 0)
|
248
|
+
if is_set_operate
|
249
|
+
else (k_res > 0 and v_res > 0)
|
250
|
+
)
|
251
|
+
for k_res, v_res in kv_pairs
|
252
|
+
]
|
253
|
+
|
254
|
+
def batch_get_v1(
|
255
|
+
self,
|
256
|
+
keys: List[str],
|
257
|
+
host_indices: torch.Tensor,
|
258
|
+
extra_info: Optional[HiCacheStorageExtraInfo] = None,
|
259
|
+
) -> List[bool]:
|
260
|
+
key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
|
261
|
+
get_results = self._get_batch_zero_copy_impl(
|
262
|
+
key_strs, buffer_ptrs, buffer_sizes
|
263
|
+
)
|
264
|
+
return self._batch_postprocess(get_results, is_set_operate=False)
|
265
|
+
|
266
|
+
def batch_set_v1(
|
267
|
+
self,
|
268
|
+
keys: List[str],
|
269
|
+
host_indices: torch.Tensor,
|
270
|
+
extra_info: Optional[HiCacheStorageExtraInfo] = None,
|
271
|
+
) -> List[bool]:
|
272
|
+
key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
|
273
|
+
exist_result = self._batch_exist(key_strs)
|
274
|
+
|
275
|
+
set_keys = []
|
276
|
+
set_buffer_ptrs = []
|
277
|
+
set_buffer_sizes = []
|
278
|
+
set_indices = []
|
279
|
+
set_results = [-1] * len(key_strs)
|
280
|
+
for i in range(len(key_strs)):
|
281
|
+
if exist_result[i] != 1:
|
282
|
+
set_keys.append(key_strs[i])
|
283
|
+
set_buffer_ptrs.append(buffer_ptrs[i])
|
284
|
+
set_buffer_sizes.append(buffer_sizes[i])
|
285
|
+
set_indices.append(i)
|
286
|
+
else:
|
287
|
+
set_results[i] = 0
|
288
|
+
|
289
|
+
# Only set non-existing keys to storage
|
290
|
+
if len(set_keys) > 0:
|
291
|
+
put_results = self._put_batch_zero_copy_impl(
|
292
|
+
set_keys, set_buffer_ptrs, set_buffer_sizes
|
293
|
+
)
|
294
|
+
for i in range(len(set_indices)):
|
295
|
+
set_results[set_indices[i]] = put_results[i]
|
296
|
+
|
297
|
+
return self._batch_postprocess(set_results, is_set_operate=True)
|
298
|
+
|
192
299
|
def set(
|
193
300
|
self,
|
194
301
|
key,
|