sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,121 @@
|
|
1
|
+
try:
|
2
|
+
from lmcache.integration.sglang.sglang_adapter import (
|
3
|
+
LMCacheLayerwiseConnector,
|
4
|
+
LoadMetadata,
|
5
|
+
StoreMetadata,
|
6
|
+
)
|
7
|
+
except ImportError:
|
8
|
+
raise RuntimeError(
|
9
|
+
"LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
|
10
|
+
)
|
11
|
+
|
12
|
+
import os
|
13
|
+
|
14
|
+
import torch
|
15
|
+
|
16
|
+
from sglang.srt.configs.model_config import ModelConfig
|
17
|
+
|
18
|
+
# Point LMCache at its experimental mode and a config file before the
# connector is constructed in the test below.
# NOTE(review): these are set after the lmcache adapter import at the top of
# the file; confirm LMCache reads them lazily (e.g. at connector construction)
# rather than at import time.
# NOTE(review): "example_config.yaml" is a relative path, presumably resolved
# against the current working directory — run from the directory holding it.
os.environ.update(
    {
        "LMCACHE_USE_EXPERIMENTAL": "True",
        "LMCACHE_CONFIG_FILE": "example_config.yaml",
    }
)
|
20
|
+
|
21
|
+
|
22
|
+
def test_load_store_metadata():
    """Round-trip a dummy KV cache through ``LMCacheLayerwiseConnector``.

    Stores the K/V rows of ``input_id_len`` fake tokens, wipes every layer of
    the K/V buffers, reloads them layer by layer, and checks that the reloaded
    rows match a snapshot taken before the wipe.
    """
    model_config = ModelConfig(
        model_path="Qwen/Qwen3-4B",
    )

    # Generate dummy KV cache: one (buffer_size, head_num, head_dim) bf16
    # tensor per layer for K and for V.
    head_num = model_config.num_key_value_heads
    head_dim = model_config.head_dim
    layer_num = model_config.num_hidden_layers
    buffer_size = 256
    input_id_len = 16

    k_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    v_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)
    # try/finally so the connector is closed even when an assertion fails.
    try:
        fake_token_ids = torch.randint(
            0, model_config.vocab_size, (input_id_len,)
        ).tolist()
        # Distinct slots: randint could map two different tokens onto the
        # same KV slot, which a real allocator never does.
        fake_kv_indices = torch.randperm(buffer_size)[:input_id_len]
        offset = 0

        store_metadata = StoreMetadata(
            last_node=None,
            token_ids=fake_token_ids,
            kv_indices=fake_kv_indices,
            offset=offset,
        )

        load_metadata = LoadMetadata(
            token_ids=fake_token_ids,
            slot_mapping=fake_kv_indices,
            offset=offset,
        )

        current_stream = torch.cuda.current_stream()

        # Before anything is stored, the lookup should retrieve no tokens.
        retrieve_token_num = connector.start_load_kv(load_metadata)
        assert retrieve_token_num == 0

        connector.store_kv(store_metadata)
        current_stream.synchronize()

        # Snapshot the ground truth. Indexing with an index tensor returns a
        # copy, so zeroing the buffers below does not affect the snapshot.
        gt_key_buffer = [k[fake_kv_indices] for k in k_buffer]
        gt_value_buffer = [v[fake_kv_indices] for v in v_buffer]

        # Clear EVERY layer. The previous loop reused a stale index and only
        # ever zeroed the last layer, so the check below was vacuous for the
        # other layers.
        for k, v in zip(k_buffer, v_buffer):
            k.zero_()
            v.zero_()

        retrieve_token_num = connector.start_load_kv(load_metadata)
        assert retrieve_token_num == input_id_len

        for i in range(layer_num):
            current_stream.synchronize()
            connector.load_kv_layerwise(i)

        current_stream.synchronize()

        # The reloaded rows must match the pre-wipe snapshot on every layer.
        for i in range(layer_num):
            assert torch.allclose(k_buffer[i][fake_kv_indices], gt_key_buffer[i])
            assert torch.allclose(v_buffer[i][fake_kv_indices], gt_value_buffer[i])

        print("================================================")
        print("TEST_LOAD_STORE_METADATA PASSED!")
        print("================================================")
    finally:
        connector.close()


if __name__ == "__main__":
    test_load_store_metadata()
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import hashlib
|
2
1
|
import json
|
3
2
|
import logging
|
4
3
|
import os
|
@@ -6,10 +5,8 @@ import uuid
|
|
6
5
|
from dataclasses import dataclass
|
7
6
|
from typing import Any, List, Optional
|
8
7
|
|
9
|
-
import numpy as np
|
10
8
|
import torch
|
11
9
|
|
12
|
-
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
13
10
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
14
11
|
|
15
12
|
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
|
@@ -75,6 +72,26 @@ class MooncakeStoreConfig:
|
|
75
72
|
master_server_address=os.getenv("MOONCAKE_MASTER"),
|
76
73
|
)
|
77
74
|
|
75
|
+
@staticmethod
|
76
|
+
def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
|
77
|
+
"""Load config from extra_config dictionary."""
|
78
|
+
if "master_server_address" not in extra_config:
|
79
|
+
raise ValueError("master_server_address is required in extra_config")
|
80
|
+
|
81
|
+
return MooncakeStoreConfig(
|
82
|
+
local_hostname=extra_config.get("local_hostname", "localhost"),
|
83
|
+
metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
|
84
|
+
global_segment_size=extra_config.get(
|
85
|
+
"global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
|
86
|
+
),
|
87
|
+
local_buffer_size=extra_config.get(
|
88
|
+
"local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
|
89
|
+
),
|
90
|
+
protocol=extra_config.get("protocol", "tcp"),
|
91
|
+
device_name=extra_config.get("device_name", "auto"),
|
92
|
+
master_server_address=extra_config["master_server_address"],
|
93
|
+
)
|
94
|
+
|
78
95
|
def __post_init__(self):
|
79
96
|
if self.device_name == "auto":
|
80
97
|
os.environ["MC_MS_AUTO_DISC"] = "1"
|
@@ -96,8 +113,26 @@ class MooncakeStore(HiCacheStorage):
|
|
96
113
|
|
97
114
|
try:
|
98
115
|
self.store = MooncakeDistributedStore()
|
99
|
-
|
100
|
-
|
116
|
+
|
117
|
+
extra_config = (
|
118
|
+
getattr(storage_config, "extra_config", None)
|
119
|
+
if storage_config
|
120
|
+
else None
|
121
|
+
)
|
122
|
+
# Load configuration with master_server_address prioritized from extra_config if available
|
123
|
+
if (
|
124
|
+
extra_config is not None
|
125
|
+
and extra_config.get("master_server_address") is not None
|
126
|
+
):
|
127
|
+
# Load from extra_config
|
128
|
+
self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
|
129
|
+
logger.info(
|
130
|
+
"Mooncake Configuration loaded from extra_config successfully."
|
131
|
+
)
|
132
|
+
else:
|
133
|
+
# Load from environment variables
|
134
|
+
self.config = MooncakeStoreConfig.load_from_env()
|
135
|
+
logger.info("Mooncake Configuration loaded from env successfully.")
|
101
136
|
|
102
137
|
ret_code = self.store.setup(
|
103
138
|
self.config.local_hostname,
|
@@ -154,20 +189,36 @@ class MooncakeStore(HiCacheStorage):
|
|
154
189
|
target_location: Optional[List[int]] = None,
|
155
190
|
target_sizes: Optional[List[int]] = None,
|
156
191
|
) -> bool:
|
157
|
-
|
192
|
+
# Only support zero copy set for now
|
193
|
+
assert target_location is not None and target_sizes is not None
|
194
|
+
exist_result = self._batch_exist([key])
|
195
|
+
if exist_result[0] == 1:
|
196
|
+
return True
|
197
|
+
put_result = self._put_batch_zero_copy_impl(
|
198
|
+
[key], [target_location], [target_sizes]
|
199
|
+
)
|
200
|
+
return put_result[0] == 0
|
158
201
|
|
159
202
|
def batch_set(
|
160
203
|
self,
|
161
204
|
keys: List[str],
|
162
|
-
|
205
|
+
values: Optional[List[torch.Tensor]] = None,
|
206
|
+
target_locations: Optional[List[int]] = None,
|
163
207
|
target_sizes: Optional[List[int]] = None,
|
164
208
|
) -> bool:
|
165
|
-
|
209
|
+
# Only support zero copy set for now
|
210
|
+
assert target_locations is not None and target_sizes is not None
|
211
|
+
assert len(keys) == len(target_locations) == len(target_sizes)
|
212
|
+
|
166
213
|
if len(keys) == 0:
|
167
214
|
return False
|
168
215
|
|
169
216
|
for i in range(len(keys)):
|
170
|
-
if
|
217
|
+
if (
|
218
|
+
keys[i] is None
|
219
|
+
or target_locations[i] is None
|
220
|
+
or target_sizes[i] is None
|
221
|
+
):
|
171
222
|
return False
|
172
223
|
|
173
224
|
exist_result = self._batch_exist(keys)
|
@@ -178,7 +229,7 @@ class MooncakeStore(HiCacheStorage):
|
|
178
229
|
for i in range(len(keys)):
|
179
230
|
if exist_result[i] != 1:
|
180
231
|
set_keys.append(keys[i])
|
181
|
-
set_target_locations.append(
|
232
|
+
set_target_locations.append(target_locations[i])
|
182
233
|
set_target_sizes.append(target_sizes[i])
|
183
234
|
set_indices.append(i)
|
184
235
|
# Only set non-existing keys to storage
|
@@ -203,18 +254,24 @@ class MooncakeStore(HiCacheStorage):
|
|
203
254
|
target_location: Optional[Any] = None,
|
204
255
|
target_sizes: Optional[Any] = None,
|
205
256
|
) -> bool:
|
206
|
-
|
257
|
+
assert target_location is not None and target_sizes is not None
|
258
|
+
get_result = self._get_batch_zero_copy_impl(
|
259
|
+
[key], [target_location], [target_sizes]
|
260
|
+
)
|
261
|
+
return get_result[0] >= 0
|
207
262
|
|
208
263
|
def batch_get(
|
209
264
|
self,
|
210
265
|
keys: List[str],
|
211
|
-
|
266
|
+
target_locations: Optional[Any] = None,
|
212
267
|
target_sizes: Optional[Any] = None,
|
213
268
|
) -> int:
|
214
|
-
assert len(keys) == len(
|
269
|
+
assert len(keys) == len(target_locations) == len(target_sizes)
|
215
270
|
if len(keys) == 0:
|
216
271
|
return 0
|
217
|
-
get_result = self._get_batch_zero_copy_impl(
|
272
|
+
get_result = self._get_batch_zero_copy_impl(
|
273
|
+
keys, target_locations, target_sizes
|
274
|
+
)
|
218
275
|
if self.is_mla_backend:
|
219
276
|
key_multiplier = 1
|
220
277
|
else:
|
@@ -225,7 +282,8 @@ class MooncakeStore(HiCacheStorage):
|
|
225
282
|
return len(keys) // key_multiplier
|
226
283
|
|
227
284
|
def exists(self, key) -> bool:
|
228
|
-
|
285
|
+
exist_result = self._batch_exist([key])
|
286
|
+
return exist_result[0] == 1
|
229
287
|
|
230
288
|
def batch_exists(self, keys) -> int:
|
231
289
|
if self.is_mla_backend:
|
@@ -244,16 +302,13 @@ class MooncakeStore(HiCacheStorage):
|
|
244
302
|
return i // key_multiplier
|
245
303
|
return len(query_keys) // key_multiplier
|
246
304
|
|
247
|
-
def delete(self, key) -> None:
|
248
|
-
raise (NotImplementedError)
|
249
|
-
|
250
305
|
def close(self):
|
251
306
|
# MooncakeDistributedStore will automatically call the destructor, so
|
252
307
|
# it is unnecessary to close it manually.
|
253
308
|
pass
|
254
309
|
|
255
310
|
def clear(self) -> None:
|
256
|
-
|
311
|
+
self.store.remove_all()
|
257
312
|
|
258
313
|
def _put_batch_zero_copy_impl(
|
259
314
|
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
@@ -0,0 +1,161 @@
|
|
1
|
+
import logging
|
2
|
+
import uuid
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from mooncake_store import MooncakeStore
|
6
|
+
|
7
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
8
|
+
|
9
|
+
# Timestamped, INFO-level logging for this standalone test script.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
def generate_batch_query_keys(kv_num: int, config: HiCacheStorageConfig):
    """Build a fresh set of unique keys for one batch test.

    Returns ``(set_keys, get_keys, exist_keys)``:

    * ``exist_keys`` — the ``kv_num`` base keys (``test_<uuid4>``).
    * ``set_keys`` — the per-tensor storage keys. For MLA models there is one
      ``<base>_k`` per base key; otherwise there is a ``<base>_<tp_rank>_k`` /
      ``<base>_<tp_rank>_v`` pair.
    * ``get_keys`` — the very same list object as ``set_keys``.
    """
    base_keys = [f"test_{uuid.uuid4()}" for _ in range(kv_num)]

    storage_keys = []
    for base in base_keys:
        if not config.is_mla_model:
            rank = config.tp_rank
            storage_keys.extend((f"{base}_{rank}_k", f"{base}_{rank}_v"))
        else:
            storage_keys.append(f"{base}_k")

    return storage_keys, storage_keys, base_keys
|
30
|
+
|
31
|
+
|
32
|
+
def test_single_operation():
    """Test the set/exists/get APIs with a single key-value pair.

    One slice of a registered host buffer is written under a fresh key. The
    key is then checked for existence, and the value is read back into a
    second, non-overlapping slice of the same buffer so the two can be
    compared.
    """
    print("=" * 100)
    print("Testing single operation")

    # Element count, not bytes: 16 Mi float32 elements = 64 MiB.
    buffer_size = 1024 * 1024 * 16
    value_elements = 1024
    store = MooncakeStore()
    buffer = torch.randn(buffer_size, dtype=torch.float32)
    store.register_buffer(buffer)
    # Size of one value in bytes (elements * bytes per element).
    value_size = value_elements * buffer.element_size()

    key = str(uuid.uuid4())
    set_slice = buffer[:value_elements]
    get_slice = buffer[value_elements : 2 * value_elements]
    set_location = set_slice.data_ptr()
    get_location = get_slice.data_ptr()

    # Test set operation
    result = store.set(key, target_location=set_location, target_sizes=value_size)
    assert result is True, f"❌set operation failed for key: {key}"

    # Test exists operation
    assert store.exists(key), f"❌key {key} should exist after set operation"

    # Test get operation
    result = store.get(key, target_location=get_location, target_sizes=value_size)
    assert result is True, f"❌get operation failed for key: {key}"

    # The slice read back must match the slice that was written.
    assert torch.allclose(
        set_slice, get_slice, atol=1e-6
    ), f"❌get operation failed for key: {key}"

    logger.info("✅ Single operation passed")
|
67
|
+
|
68
|
+
|
69
|
+
def test_batch_operation(config: HiCacheStorageConfig):
    """Test the batch set/exists/get APIs with multiple key-value pairs.

    Each storage key gets its own ``value_elements``-sized slot in a
    registered host buffer. Values are written from the first
    ``len(set_keys)`` slots and read back into the slots right after them, so
    the write and read regions never overlap.
    """
    print("=" * 100)
    print(f"Testing batch operation with config: {config}")

    # Element count, not bytes: 16 Mi float32 elements = 64 MiB.
    buffer_size = 1024 * 1024 * 16
    value_elements = 256
    kv_num = 13
    store = MooncakeStore(config)
    buffer = torch.randn(buffer_size, dtype=torch.float32)
    store.register_buffer(buffer)
    # Size of one value in bytes (elements * bytes per element).
    value_size = value_elements * buffer.element_size()

    set_keys, get_keys, exist_keys = generate_batch_query_keys(kv_num, config)

    def _slot(index):
        # View of the index-th value-sized slot of the registered buffer.
        return buffer[index * value_elements : (index + 1) * value_elements]

    set_slices = [_slot(i) for i in range(len(set_keys))]
    set_locations = [s.data_ptr() for s in set_slices]
    target_sizes = [value_size] * len(set_keys)

    # Test batch set operation
    result = store.batch_set(
        set_keys, target_locations=set_locations, target_sizes=target_sizes
    )
    assert result is True, "❌batch set operation failed"

    # Test batch exists operation. batch_exists returns an int count, not a
    # bool, so assert that ALL keys exist rather than merely "at least one".
    exist_count = store.batch_exists(exist_keys)
    assert (
        exist_count == kv_num
    ), f"❌expected {kv_num} keys to exist after batch set, got {exist_count}"

    # Test batch get operation: read into the slots following the set region.
    get_slices = [_slot(len(set_keys) + i) for i in range(len(get_keys))]
    get_locations = [s.data_ptr() for s in get_slices]
    get_sizes = [value_size] * len(get_keys)
    result = store.batch_get(
        get_keys, target_locations=get_locations, target_sizes=get_sizes
    )
    assert result == kv_num, "❌batch get operation failed"
    for key, set_slice, get_slice in zip(get_keys, set_slices, get_slices):
        assert torch.allclose(
            set_slice, get_slice, atol=1e-6
        ), f"❌batch get operation failed for key: {key}"

    logger.info("✅ Batch operation passed")
|
121
|
+
|
122
|
+
|
123
|
+
if __name__ == "__main__":
    test_single_operation()

    # (is_mla_model, tp_rank, tp_size) combinations covering both key layouts
    # produced by generate_batch_query_keys (MLA vs. per-rank K/V keys).
    for is_mla_model, tp_rank, tp_size in [
        (False, 0, 1),
        (True, 0, 1),
        (False, 1, 4),
        (True, 3, 8),
    ]:
        test_batch_operation(
            HiCacheStorageConfig(
                is_mla_model=is_mla_model,
                tp_rank=tp_rank,
                tp_size=tp_size,
                model_name=None,
                is_page_first_layout=True,
            )
        )

    logger.info("✅ All tests passed")
|
@@ -60,8 +60,6 @@ class TreeNode:
|
|
60
60
|
self.last_access_time = time.monotonic()
|
61
61
|
|
62
62
|
self.hit_count = 0
|
63
|
-
# indicating the node is loading KV cache from host
|
64
|
-
self.loading = False
|
65
63
|
# store the host indices of KV cache
|
66
64
|
self.host_value = None
|
67
65
|
|
@@ -464,7 +462,7 @@ class SWARadixCache(BasePrefixCache):
|
|
464
462
|
self.req_to_token_pool.free(req.req_pool_idx)
|
465
463
|
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
466
464
|
|
467
|
-
def cache_unfinished_req(self, req: Req) -> None:
|
465
|
+
def cache_unfinished_req(self, req: Req, chunked=False) -> None:
|
468
466
|
"""Cache request when it is unfinished."""
|
469
467
|
if self.disable:
|
470
468
|
kv_indices = self.req_to_token_pool.req_to_token[
|