sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, List, Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from aibrix_kvcache import (
|
6
|
+
BaseKVCacheManager,
|
7
|
+
BlockHashes,
|
8
|
+
KVCacheBlockLayout,
|
9
|
+
KVCacheBlockSpec,
|
10
|
+
KVCacheConfig,
|
11
|
+
KVCacheTensorSpec,
|
12
|
+
ModelSpec,
|
13
|
+
)
|
14
|
+
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
|
15
|
+
|
16
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
17
|
+
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
|
22
|
+
class AibrixKVCacheStorage(HiCacheStorage):
    """HiCacheStorage backend backed by AIBrix KVCache.

    Each HiCache key addresses one KV page; pages are stored as AIBrix cache
    blocks described by ``KVCacheBlockSpec``. MLA models are not supported yet.
    """

    def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
        if storage_config is not None:
            self.is_mla_backend = storage_config.is_mla_model
            self.local_rank = storage_config.tp_rank
        else:
            self.is_mla_backend = False
            self.local_rank = 0
        kv_cache = mem_pool.device_pool
        self.page_size = mem_pool.page_size
        self.kv_cache_dtype = kv_cache.dtype
        self.layer_num = kv_cache.layer_num
        # Global KV-head ids owned by this TP rank (rank-local heads offset by
        # rank * head_num).
        self.kv_head_ids = [
            self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
        ]
        if not self.is_mla_backend:
            self.layer_ids = range(
                kv_cache.start_layer, kv_cache.end_layer
            )  # for pipeline parallel

            self.block_spec = KVCacheBlockSpec(
                block_ntokens=self.page_size,
                block_dtype=self.kv_cache_dtype,
                block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
                tensor_spec=KVCacheTensorSpec(
                    heads=self.kv_head_ids,
                    layers=self.layer_ids,
                    head_size=kv_cache.head_dim,
                ),
            )
            logger.info(self.block_spec)
            config = KVCacheConfig(
                block_spec=self.block_spec, model_spec=ModelSpec(102400)
            )
            self.kv_cache_manager = BaseKVCacheManager(config)
        else:
            raise NotImplementedError(
                "MLA is not supported by AibrixKVCacheStorage yet."
            )

    def _aibrix_kvcache_metrics_report(self) -> str:
        """Summarize and reset the manager metrics, returning the summary.

        Returning the summary string (instead of ``None``) lets callers pass it
        to ``log_every_n_seconds`` as the actual log message.
        """
        summary = self.kv_cache_manager.metrics.summary()
        self.kv_cache_manager.metrics.reset()
        return summary

    def batch_get(
        self,
        keys: List[str],
        target_locations: List[torch.Tensor],
        target_sizes: Optional[Any] = None,
    ) -> List[torch.Tensor | None]:
        """Fetch the pages for ``keys`` into the flattened ``target_locations``.

        Returns ``target_locations`` on a successful acquire, otherwise a list
        of ``None`` (one per key). ``target_sizes`` is unused.
        """
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.acquire(None, block_hash)
        # NOTE(review): the message argument is still evaluated on every call,
        # so metrics are summarized/reset per call even though the log line
        # itself is throttled to once per second.
        log_every_n_seconds(
            logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1
        )
        if status.is_ok():
            num_fetched_tokens, handle = status.value
            kv_blocks = handle.to_tensors()
            # assumes acquire() returns one block per requested key — TODO confirm
            assert len(kv_blocks) == len(target_locations)
            for i in range(len(kv_blocks)):
                assert (
                    target_locations[i].nbytes == kv_blocks[i].nbytes
                ), f"{target_locations[i].nbytes}, {kv_blocks[i].nbytes}"
                target_locations[i].copy_(kv_blocks[i].flatten())
            handle.release()
            return target_locations

        return [None] * len(keys)

    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Single-key convenience wrapper over :meth:`batch_get`."""
        return self.batch_get([key], [target_location], [target_size])[0]

    def batch_set(
        self,
        keys: List[str],
        values: Optional[Any] = None,
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """Store ``values`` (one tensor per key) under ``keys``.

        Returns True only when every page was allocated, copied, and committed.
        ``target_locations`` and ``target_sizes`` are unused.
        """
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.allocate_for(None, block_hash)
        if not status.is_ok():
            logger.warning(
                f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
            )
            return False
        handle = status.value
        tensors = handle.to_tensors()
        if len(tensors) != len(values):
            logger.warning("aibrix_kvcache set allocate not enough")
            return False
        for i in range(len(tensors)):
            assert (
                tensors[i].nbytes == values[i].nbytes
            ), f"{tensors[i].nbytes}, {values[i].nbytes}"
            # Reshape the read-only source to the destination's shape. Copying
            # through a reshape of the *destination* could silently write into
            # a temporary when reshape has to return a copy (non-contiguous
            # view); reshaping the source guarantees the data lands in
            # tensors[i].
            tensors[i].copy_(values[i].reshape(tensors[i].shape))
        status = self.kv_cache_manager.put(None, block_hash, handle)
        if not status.is_ok():
            logger.info(
                f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
            )
            return False
        completed = status.value
        # put() reports the number of committed tokens; success means every
        # page (page_size tokens per key) was committed.
        return completed == len(keys) * self.page_size

    def set(
        self,
        key: str,
        value: Optional[Any] = None,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """Single-key convenience wrapper over :meth:`batch_set`."""
        return self.batch_set([key], [value], [target_location], [target_size])

    def batch_exists(self, keys: List[str]) -> int:
        """Return how many of the leading ``keys`` already exist (prefix count)."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.exists(None, block_hash)
        if status.is_ok():
            # exists() counts tokens; convert back to whole pages/keys.
            return status.value // self.page_size
        return 0

    def exists(self, key: str) -> bool | dict:
        """Return True when ``key`` is present in the cache."""
        return self.batch_exists([key]) > 0
|
@@ -0,0 +1,109 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.distributed
|
6
|
+
from aibrix_kvcache import (
|
7
|
+
BaseKVCacheManager,
|
8
|
+
GroupAwareKVCacheManager,
|
9
|
+
KVCacheBlockLayout,
|
10
|
+
KVCacheBlockSpec,
|
11
|
+
KVCacheConfig,
|
12
|
+
KVCacheMetrics,
|
13
|
+
KVCacheTensorSpec,
|
14
|
+
ModelSpec,
|
15
|
+
TokenListView,
|
16
|
+
)
|
17
|
+
from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
|
18
|
+
from aibrix_kvcache_storage import AibrixKVCacheStorage
|
19
|
+
from torch.distributed import Backend, ProcessGroup
|
20
|
+
|
21
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
22
|
+
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
23
|
+
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
|
24
|
+
|
25
|
+
logging.basicConfig(
|
26
|
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
27
|
+
)
|
28
|
+
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
def setup():
    """Populate the torch.distributed rendezvous env vars for a 1-process run."""
    env_defaults = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "63886",
    }
    for var_name, var_value in env_defaults.items():
        os.environ[var_name] = var_value
|
37
|
+
|
38
|
+
|
39
|
+
class AIBrixKVCacheStorageTest:
    """Round-trip smoke test for AibrixKVCacheStorage over a CPU MHA host pool."""

    def test_with_page_size(self):
        """Exercise set/get/exists for page sizes 1 and 2.

        For each page size: write ``query_length`` random pages, read them all
        back, then read back only the second half (``partial_keys``) and verify
        both full and partial existence counts.
        """
        config = HiCacheStorageConfig(
            tp_rank=0,
            tp_size=1,
            is_mla_model=False,
            is_page_first_layout=True,
            model_name="test",
        )
        for page_size in range(1, 3):
            logger.info(f"page_size: {page_size}")
            batch_size = 2
            head_num = 1
            layer_num = 64
            head_dim = 128
            kv_cache = MHATokenToKVPool(
                1024,
                page_size,
                torch.float16,
                head_num,
                head_dim,
                layer_num,
                "cpu",
                False,
                0,
                layer_num,
            )
            mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
            query_length = batch_size * 2
            partial = batch_size
            self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
            # Leading dim of 2 = K and V planes of one page.
            target_shape = (2, layer_num, page_size, head_num, head_dim)
            rand_tensor = [
                torch.rand(target_shape, dtype=torch.float16)
                for _ in range(query_length)
            ]
            keys = ["hash" + str(i) for i in range(query_length)]
            partial_keys = keys[batch_size:query_length]
            assert self.aibrix_kvcache.batch_exists(keys) == 0
            assert self.aibrix_kvcache.batch_set(keys, rand_tensor)
            # Pre-fill destinations with random data so a no-op get is caught.
            get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(query_length)
            ]
            self.aibrix_kvcache.batch_get(keys, get_tensor)
            for i in range(query_length):
                assert torch.equal(get_tensor[i], rand_tensor[i].flatten())
            assert self.aibrix_kvcache.batch_exists(keys) == query_length
            assert self.aibrix_kvcache.batch_exists(partial_keys) == partial
            partial_get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(partial)
            ]
            self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
            for i in range(partial):
                # partial_keys starts at index batch_size == partial.
                assert torch.equal(
                    partial_get_tensor[i], rand_tensor[i + partial].flatten()
                )
            log_every_n_seconds(
                logger,
                logging.INFO,
                self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
                1,
            )
|
104
|
+
|
105
|
+
|
106
|
+
if __name__ == "__main__":
    # Standalone entry point: set up single-process distributed env vars,
    # then run the storage round-trip smoke test.
    setup()
    test = AIBrixKVCacheStorageTest()
    test.test_with_page_size()
|
@@ -0,0 +1,223 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to SGLang project
|
3
|
+
|
4
|
+
import importlib
|
5
|
+
import logging
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict
|
7
|
+
|
8
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
pass
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class StorageBackendFactory:
    """Factory for creating storage backend instances with support for dynamic loading."""

    # name -> {"loader": lazy class loader, "module_path": str, "class_name": str}
    _registry: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _load_backend_class(
        module_path: str, class_name: str, backend_name: str
    ) -> type[HiCacheStorage]:
        """Load and validate a backend class from module path.

        Raises:
            ImportError: the module could not be imported.
            AttributeError: the class is missing from the module.
            TypeError: the class does not inherit from HiCacheStorage.
        """
        try:
            module = importlib.import_module(module_path)
            backend_class = getattr(module, class_name)
            if not issubclass(backend_class, HiCacheStorage):
                raise TypeError(
                    f"Backend class {class_name} must inherit from HiCacheStorage"
                )
            return backend_class
        except ImportError as e:
            raise ImportError(
                f"Failed to import backend '{backend_name}' from '{module_path}': {e}"
            ) from e
        except AttributeError as e:
            raise AttributeError(
                f"Class '{class_name}' not found in module '{module_path}': {e}"
            ) from e

    @classmethod
    def register_backend(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a storage backend with lazy loading.

        Args:
            name: Backend identifier
            module_path: Python module path containing the backend class
            class_name: Name of the backend class
        """
        if name in cls._registry:
            logger.warning(f"Backend '{name}' is already registered, overwriting")

        def loader() -> type[HiCacheStorage]:
            """Lazy loader function to import the backend class."""
            return cls._load_backend_class(module_path, class_name, name)

        cls._registry[name] = {
            "loader": loader,
            "module_path": module_path,
            "class_name": class_name,
        }

    @classmethod
    def create_backend(
        cls,
        backend_name: str,
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
        **kwargs,
    ) -> HiCacheStorage:
        """Create a storage backend instance.

        Args:
            backend_name: Name of the backend to create
            storage_config: Storage configuration
            mem_pool_host: Memory pool host object
            **kwargs: Additional arguments passed to external backends
        Returns:
            Initialized storage backend instance
        Raises:
            ValueError: If backend is not registered and cannot be dynamically loaded
            ImportError: If backend module cannot be imported
            Exception: If backend initialization fails
        """
        # First check if backend is already registered
        if backend_name in cls._registry:
            registry_entry = cls._registry[backend_name]
            backend_class = registry_entry["loader"]()
            logger.info(
                f"Creating storage backend '{backend_name}' "
                f"({registry_entry['module_path']}.{registry_entry['class_name']})"
            )
            return cls._create_builtin_backend(
                backend_name, backend_class, storage_config, mem_pool_host
            )

        # Try to dynamically load backend from extra_config
        if backend_name == "dynamic" and storage_config.extra_config is not None:
            backend_config = storage_config.extra_config
            return cls._create_dynamic_backend(
                backend_config, storage_config, mem_pool_host, **kwargs
            )

        # Backend not found
        available_backends = list(cls._registry.keys())

        raise ValueError(
            f"Unknown storage backend '{backend_name}'. "
            f"Registered backends: {available_backends}. "
        )

    @classmethod
    def _create_dynamic_backend(
        cls,
        backend_config: Dict[str, Any],
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
        **kwargs,
    ) -> HiCacheStorage:
        """Create a backend dynamically from configuration."""
        required_fields = ["backend_name", "module_path", "class_name"]
        for field in required_fields:
            if field not in backend_config:
                raise ValueError(
                    f"Missing required field '{field}' in backend config for 'dynamic' backend"
                )

        backend_name = backend_config["backend_name"]
        module_path = backend_config["module_path"]
        class_name = backend_config["class_name"]

        try:
            # Import the backend class
            backend_class = cls._load_backend_class(
                module_path, class_name, backend_name
            )

            logger.info(
                f"Creating dynamic storage backend '{backend_name}' "
                f"({module_path}.{class_name})"
            )

            # Create the backend instance with storage_config.
            # NOTE(review): kwargs is passed as a single positional dict, not
            # expanded with ** — presumably the external-backend contract takes
            # (storage_config, extra_kwargs); confirm against implementations.
            return backend_class(storage_config, kwargs)
        except Exception as e:
            logger.error(
                f"Failed to create dynamic storage backend '{backend_name}': {e}"
            )
            raise

    @classmethod
    def _create_builtin_backend(
        cls,
        backend_name: str,
        backend_class: type[HiCacheStorage],
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
    ) -> HiCacheStorage:
        """Create built-in backend with original initialization logic.

        Each built-in backend has its own constructor signature, hence the
        per-name dispatch below.
        """
        if backend_name == "file":
            return backend_class(storage_config)
        elif backend_name == "nixl":
            return backend_class()
        elif backend_name == "mooncake":
            backend = backend_class(storage_config)
            return backend
        elif backend_name == "aibrix":
            backend = backend_class(storage_config, mem_pool_host)
            return backend
        elif backend_name == "hf3fs":
            # Calculate bytes_per_page based on memory pool layout
            if mem_pool_host.layout == "page_first":
                bytes_per_page = (
                    mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
                )
            elif mem_pool_host.layout == "layer_first":
                bytes_per_page = (
                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
                )
            else:
                # Fail fast with a clear error instead of hitting an
                # UnboundLocalError on bytes_per_page below.
                raise ValueError(
                    f"Unsupported memory pool layout for hf3fs backend: "
                    f"{mem_pool_host.layout}"
                )

            dtype = mem_pool_host.dtype
            return backend_class.from_env_config(bytes_per_page, dtype, storage_config)
        elif backend_name == "eic":
            return backend_class(storage_config, mem_pool_host)
        else:
            raise ValueError(f"Unknown built-in backend: {backend_name}")
|
188
|
+
|
189
|
+
|
190
|
+
# Register built-in storage backends
# Registration is lazy: the backend module is only imported when the backend
# is actually created, so optional dependencies (nixl, mooncake, hf3fs, ...)
# are not required at import time.
StorageBackendFactory.register_backend(
    "file", "sglang.srt.mem_cache.hicache_storage", "HiCacheFile"
)

StorageBackendFactory.register_backend(
    "nixl",
    "sglang.srt.mem_cache.storage.nixl.hicache_nixl",
    "HiCacheNixl",
)

StorageBackendFactory.register_backend(
    "mooncake",
    "sglang.srt.mem_cache.storage.mooncake_store.mooncake_store",
    "MooncakeStore",
)

StorageBackendFactory.register_backend(
    "hf3fs",
    "sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs",
    "HiCacheHF3FS",
)

StorageBackendFactory.register_backend(
    "aibrix",
    "sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage",
    "AibrixKVCacheStorage",
)

StorageBackendFactory.register_backend(
    "eic",
    "sglang.srt.mem_cache.storage.eic.eic_storage",
    "EICStorage",
)
|