sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -5,19 +5,16 @@ import logging
|
|
5
5
|
import os
|
6
6
|
import signal
|
7
7
|
import threading
|
8
|
+
import time
|
8
9
|
from abc import ABC, abstractmethod
|
9
10
|
from functools import wraps
|
10
11
|
from typing import Any, List, Optional, Tuple
|
11
12
|
|
12
13
|
import torch
|
13
14
|
|
14
|
-
from sglang.srt.
|
15
|
-
from sglang.srt.
|
16
|
-
|
17
|
-
is_dp_attention_enabled,
|
18
|
-
)
|
19
|
-
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
20
|
-
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
15
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
16
|
+
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
|
17
|
+
from sglang.srt.metrics.collector import StorageMetrics
|
21
18
|
|
22
19
|
logger = logging.getLogger(__name__)
|
23
20
|
|
@@ -117,7 +114,36 @@ def synchronized():
|
|
117
114
|
return _decorator
|
118
115
|
|
119
116
|
|
117
|
+
def create_hf3fs_client(
|
118
|
+
path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
|
119
|
+
) -> Hf3fsClient:
|
120
|
+
"""Factory function to create appropriate HF3FS client.
|
121
|
+
|
122
|
+
Args:
|
123
|
+
path: File path for storage
|
124
|
+
size: Total size of storage file
|
125
|
+
bytes_per_page: Bytes per page
|
126
|
+
entries: Number of entries for batch operations
|
127
|
+
use_mock: Whether to use mock client instead of real usrbio client
|
128
|
+
|
129
|
+
Returns:
|
130
|
+
"""
|
131
|
+
if use_mock:
|
132
|
+
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient
|
133
|
+
|
134
|
+
logger.info(f"[Rank Using Hf3fsMockClient for testing")
|
135
|
+
return Hf3fsMockClient(path, size, bytes_per_page, entries)
|
136
|
+
else:
|
137
|
+
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
|
138
|
+
Hf3fsUsrBioClient,
|
139
|
+
)
|
140
|
+
|
141
|
+
return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
|
142
|
+
|
143
|
+
|
120
144
|
class HiCacheHF3FS(HiCacheStorage):
|
145
|
+
"""HiCache backend that stores KV cache pages in HF3FS files."""
|
146
|
+
|
121
147
|
default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
|
122
148
|
|
123
149
|
def __init__(
|
@@ -130,18 +156,27 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
130
156
|
entries: int,
|
131
157
|
dtype: torch.dtype,
|
132
158
|
metadata_client: Hf3fsMetadataInterface,
|
159
|
+
is_mla_model: bool = False,
|
160
|
+
is_page_first_layout: bool = False,
|
161
|
+
use_mock_client: bool = False,
|
133
162
|
):
|
134
163
|
self.rank = rank
|
135
164
|
self.file_path = file_path
|
136
165
|
self.file_size = file_size
|
137
166
|
self.numjobs = numjobs
|
138
167
|
self.bytes_per_page = bytes_per_page
|
168
|
+
self.gb_per_page = bytes_per_page / (1 << 30)
|
139
169
|
self.entries = entries
|
140
170
|
self.dtype = dtype
|
141
171
|
self.metadata_client = metadata_client
|
142
|
-
|
172
|
+
self.is_mla_model = is_mla_model
|
173
|
+
self.is_page_first_layout = is_page_first_layout
|
143
174
|
self.numel = self.bytes_per_page // self.dtype.itemsize
|
144
175
|
self.num_pages = self.file_size // self.bytes_per_page
|
176
|
+
self.skip_backup = False
|
177
|
+
if self.is_mla_model and self.rank != 0:
|
178
|
+
self.skip_backup = True
|
179
|
+
self.rank = 0
|
145
180
|
|
146
181
|
logger.info(
|
147
182
|
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
|
@@ -152,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
152
187
|
|
153
188
|
self.ac = AtomicCounter(self.numjobs)
|
154
189
|
self.clients = [
|
155
|
-
|
156
|
-
self.file_path,
|
190
|
+
create_hf3fs_client(
|
191
|
+
self.file_path,
|
192
|
+
self.file_size,
|
193
|
+
self.bytes_per_page,
|
194
|
+
self.entries,
|
195
|
+
use_mock_client,
|
157
196
|
)
|
158
197
|
for _ in range(numjobs)
|
159
198
|
]
|
@@ -170,24 +209,57 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
170
209
|
signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
|
171
210
|
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
|
172
211
|
|
212
|
+
self.prefetch_pgs = []
|
213
|
+
self.backup_pgs = []
|
214
|
+
self.prefetch_bandwidth = []
|
215
|
+
self.backup_bandwidth = []
|
216
|
+
|
173
217
|
@staticmethod
|
174
218
|
def from_env_config(
|
175
|
-
bytes_per_page: int,
|
219
|
+
bytes_per_page: int,
|
220
|
+
dtype: torch.dtype,
|
221
|
+
storage_config: HiCacheStorageConfig = None,
|
176
222
|
) -> "HiCacheHF3FS":
|
223
|
+
"""Create a HiCacheHF3FS instance from environment configuration.
|
224
|
+
|
225
|
+
Environment:
|
226
|
+
- Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
|
227
|
+
- Falls back to a local single-machine config when the env var is not set.
|
228
|
+
|
229
|
+
Raises:
|
230
|
+
ValueError: If MLA Model is requested without global metadata server or required keys are missing.
|
231
|
+
"""
|
177
232
|
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
178
233
|
Hf3fsGlobalMetadataClient,
|
179
234
|
Hf3fsLocalMetadataClient,
|
180
235
|
)
|
181
236
|
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
237
|
+
use_mock_client = False
|
238
|
+
if storage_config is not None:
|
239
|
+
rank, is_mla_model, is_page_first_layout = (
|
240
|
+
storage_config.tp_rank,
|
241
|
+
storage_config.is_mla_model,
|
242
|
+
storage_config.is_page_first_layout,
|
187
243
|
)
|
188
244
|
|
245
|
+
if storage_config.extra_config is not None:
|
246
|
+
use_mock_client = storage_config.extra_config.get(
|
247
|
+
"use_mock_hf3fs_client", False
|
248
|
+
)
|
249
|
+
else:
|
250
|
+
rank, is_mla_model, is_page_first_layout = (
|
251
|
+
0,
|
252
|
+
False,
|
253
|
+
False,
|
254
|
+
)
|
255
|
+
|
256
|
+
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
|
257
|
+
|
189
258
|
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
190
259
|
if not config_path:
|
260
|
+
if is_mla_model:
|
261
|
+
raise ValueError(mla_unsupported_msg)
|
262
|
+
|
191
263
|
return HiCacheHF3FS(
|
192
264
|
rank=rank,
|
193
265
|
file_path=f"/data/hicache.{rank}.bin",
|
@@ -197,6 +269,8 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
197
269
|
entries=8,
|
198
270
|
dtype=dtype,
|
199
271
|
metadata_client=Hf3fsLocalMetadataClient(),
|
272
|
+
is_page_first_layout=is_page_first_layout,
|
273
|
+
use_mock_client=use_mock_client,
|
200
274
|
)
|
201
275
|
|
202
276
|
try:
|
@@ -217,26 +291,36 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
217
291
|
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
218
292
|
|
219
293
|
# Choose metadata client based on configuration
|
220
|
-
if
|
294
|
+
if config.get("metadata_server_url"):
|
221
295
|
# Use global metadata client to connect to metadata server
|
222
296
|
metadata_server_url = config["metadata_server_url"]
|
223
297
|
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
298
|
+
|
224
299
|
logger.info(
|
225
300
|
f"Using global metadata client with server url: {metadata_server_url}"
|
226
301
|
)
|
227
302
|
else:
|
303
|
+
# Enable MLA optimization only when using the global metadata client
|
304
|
+
if is_mla_model:
|
305
|
+
raise ValueError(mla_unsupported_msg)
|
306
|
+
|
228
307
|
# Use local metadata client for single-machine deployment
|
229
308
|
metadata_client = Hf3fsLocalMetadataClient()
|
230
309
|
|
310
|
+
rank_for_path = 0 if is_mla_model else rank
|
231
311
|
return HiCacheHF3FS(
|
232
312
|
rank=rank,
|
233
|
-
|
313
|
+
# Let all ranks use the same file path for MLA model
|
314
|
+
file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
|
234
315
|
file_size=int(config["file_size"]),
|
235
316
|
numjobs=int(config["numjobs"]),
|
236
317
|
bytes_per_page=bytes_per_page,
|
237
318
|
entries=int(config["entries"]),
|
238
319
|
dtype=dtype,
|
239
320
|
metadata_client=metadata_client,
|
321
|
+
is_mla_model=is_mla_model,
|
322
|
+
is_page_first_layout=is_page_first_layout,
|
323
|
+
use_mock_client=use_mock_client,
|
240
324
|
)
|
241
325
|
|
242
326
|
def get(
|
@@ -276,6 +360,8 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
276
360
|
for _ in range(len(batch_indices))
|
277
361
|
]
|
278
362
|
|
363
|
+
start_time = time.perf_counter()
|
364
|
+
|
279
365
|
futures = [
|
280
366
|
self.executor.submit(
|
281
367
|
self.clients[self.ac.next()].batch_read,
|
@@ -286,6 +372,13 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
286
372
|
]
|
287
373
|
read_results = [result for future in futures for result in future.result()]
|
288
374
|
|
375
|
+
end_time = time.perf_counter()
|
376
|
+
ionum = len(batch_indices)
|
377
|
+
self.prefetch_pgs.append(ionum)
|
378
|
+
self.prefetch_bandwidth.append(
|
379
|
+
ionum / (end_time - start_time) * self.gb_per_page
|
380
|
+
)
|
381
|
+
|
289
382
|
results = [None] * len(keys)
|
290
383
|
for batch_index, file_result, read_result in zip(
|
291
384
|
batch_indices, file_results, read_results
|
@@ -313,6 +406,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
313
406
|
[target_sizes] if target_sizes is not None else None,
|
314
407
|
)
|
315
408
|
|
409
|
+
@synchronized()
|
316
410
|
def batch_set(
|
317
411
|
self,
|
318
412
|
keys: List[str],
|
@@ -320,6 +414,10 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
320
414
|
target_locations: Optional[Any] = None,
|
321
415
|
target_sizes: Optional[Any] = None,
|
322
416
|
) -> bool:
|
417
|
+
# In MLA backend, only one rank needs to backup the KV cache
|
418
|
+
if self.skip_backup:
|
419
|
+
return True
|
420
|
+
|
323
421
|
# Todo: Add prefix block's hash key
|
324
422
|
key_with_prefix = [(key, "") for key in keys]
|
325
423
|
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
@@ -338,6 +436,8 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
338
436
|
assert value.is_contiguous()
|
339
437
|
file_values.append(value)
|
340
438
|
|
439
|
+
start_time = time.perf_counter()
|
440
|
+
|
341
441
|
futures = [
|
342
442
|
self.executor.submit(
|
343
443
|
self.clients[self.ac.next()].batch_write,
|
@@ -352,6 +452,11 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
352
452
|
for result in future.result()
|
353
453
|
]
|
354
454
|
|
455
|
+
end_time = time.perf_counter()
|
456
|
+
ionum = len(batch_indices)
|
457
|
+
self.backup_pgs.append(ionum)
|
458
|
+
self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
|
459
|
+
|
355
460
|
written_keys_to_confirm = []
|
356
461
|
results = [index[0] for index in indices]
|
357
462
|
for batch_index, write_result in zip(batch_indices, write_results):
|
@@ -371,18 +476,29 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
371
476
|
|
372
477
|
return all(results)
|
373
478
|
|
374
|
-
@synchronized()
|
375
479
|
def delete(self, key: str) -> None:
|
376
480
|
self.metadata_client.delete_keys(self.rank, [key])
|
377
481
|
|
378
|
-
@synchronized()
|
379
482
|
def exists(self, key: str) -> bool:
|
380
483
|
result = self.metadata_client.exists(self.rank, [key])
|
381
484
|
return result[0] if result else False
|
382
485
|
|
383
|
-
|
384
|
-
|
385
|
-
|
486
|
+
def batch_exists(self, keys: List[str]) -> int:
|
487
|
+
results = self.metadata_client.exists(self.rank, keys)
|
488
|
+
for i in range(len(keys)):
|
489
|
+
if not results[i]:
|
490
|
+
return i
|
491
|
+
|
492
|
+
return len(keys)
|
493
|
+
|
494
|
+
def clear(self) -> bool:
|
495
|
+
try:
|
496
|
+
self.metadata_client.clear(self.rank)
|
497
|
+
logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
|
498
|
+
return True
|
499
|
+
except Exception as e:
|
500
|
+
logger.error(f"Failed to clear HiCacheHF3FS: {e}")
|
501
|
+
return False
|
386
502
|
|
387
503
|
def close(self) -> None:
|
388
504
|
try:
|
@@ -392,3 +508,16 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
392
508
|
except Exception as e:
|
393
509
|
logger.error(f"close HiCacheHF3FS: {e}")
|
394
510
|
logger.info("close HiCacheHF3FS")
|
511
|
+
|
512
|
+
@synchronized()
|
513
|
+
def get_stats(self):
|
514
|
+
storage_metrics = StorageMetrics()
|
515
|
+
storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
|
516
|
+
storage_metrics.backup_pgs.extend(self.backup_pgs)
|
517
|
+
storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth)
|
518
|
+
storage_metrics.backup_bandwidth.extend(self.backup_bandwidth)
|
519
|
+
self.prefetch_pgs.clear()
|
520
|
+
self.backup_pgs.clear()
|
521
|
+
self.prefetch_bandwidth.clear()
|
522
|
+
self.backup_bandwidth.clear()
|
523
|
+
return storage_metrics
|
@@ -0,0 +1,280 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import threading
|
5
|
+
from typing import TYPE_CHECKING, List, Optional
|
6
|
+
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
10
|
+
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
11
|
+
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
12
|
+
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
13
|
+
|
14
|
+
try:
|
15
|
+
from lmcache.integration.sglang.sglang_adapter import (
|
16
|
+
LMCacheLayerwiseConnector,
|
17
|
+
LoadMetadata,
|
18
|
+
StoreMetadata,
|
19
|
+
)
|
20
|
+
except ImportError as e:
|
21
|
+
raise RuntimeError(
|
22
|
+
"LMCache is not installed. Please install it by running `pip install lmcache`"
|
23
|
+
) from e
|
24
|
+
|
25
|
+
if TYPE_CHECKING:
|
26
|
+
from sglang.srt.configs.model_config import ModelConfig
|
27
|
+
from sglang.srt.managers.schedule_batch import Req
|
28
|
+
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
class LayerTransferCounter:
|
33
|
+
"""Minimal adapter that lets the memory pool notify LMCache per-layer.
|
34
|
+
|
35
|
+
The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
|
36
|
+
translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
|
37
|
+
within the provided CUDA stream.
|
38
|
+
"""
|
39
|
+
|
40
|
+
def __init__(
|
41
|
+
self,
|
42
|
+
num_layers: int,
|
43
|
+
load_stream: torch.cuda.Stream,
|
44
|
+
lmc_connector: LMCacheLayerwiseConnector,
|
45
|
+
printable: bool = False,
|
46
|
+
):
|
47
|
+
self.num_layers = num_layers
|
48
|
+
self.load_stream = load_stream
|
49
|
+
self.lmc_connector = lmc_connector
|
50
|
+
|
51
|
+
def wait_until(self, layer_id: int):
|
52
|
+
# Ensure ordering of the async loads wrt compute stream(s).
|
53
|
+
self.load_stream.synchronize()
|
54
|
+
with self.load_stream:
|
55
|
+
self.lmc_connector.load_kv_layerwise(layer_id)
|
56
|
+
|
57
|
+
|
58
|
+
class LMCRadixCache(RadixCache):
|
59
|
+
"""RadixCache + LMCache IO.
|
60
|
+
|
61
|
+
This subclass adds:
|
62
|
+
- LMCache connector setup (device/host buffers, TP rank/size)
|
63
|
+
- Two CUDA streams for async load/store
|
64
|
+
- Layer-wise transfer executor wiring to the KV cache
|
65
|
+
- Overridden `match_prefix` to fetch missing prefix chunks from LMCache
|
66
|
+
- Extended cache_finalization paths to store back into LMCache
|
67
|
+
- Eviction barrier that respects any in-flight host->device stores
|
68
|
+
"""
|
69
|
+
|
70
|
+
def __init__(
|
71
|
+
self,
|
72
|
+
req_to_token_pool: ReqToTokenPool,
|
73
|
+
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
74
|
+
page_size: int,
|
75
|
+
disable: bool = False,
|
76
|
+
enable_kv_cache_events: bool = False,
|
77
|
+
model_config: Optional["ModelConfig"] = None,
|
78
|
+
tp_size: int = 1,
|
79
|
+
rank: int = 0,
|
80
|
+
tp_group: Optional[torch.distributed.ProcessGroup] = None,
|
81
|
+
):
|
82
|
+
super().__init__(
|
83
|
+
req_to_token_pool=req_to_token_pool,
|
84
|
+
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
85
|
+
page_size=page_size,
|
86
|
+
disable=disable,
|
87
|
+
enable_kv_cache_events=enable_kv_cache_events,
|
88
|
+
)
|
89
|
+
|
90
|
+
kvcache = self.token_to_kv_pool_allocator.get_kvcache()
|
91
|
+
self.lmcache_connector = LMCacheLayerwiseConnector(
|
92
|
+
sgl_config=model_config,
|
93
|
+
tp_size=tp_size,
|
94
|
+
rank=rank,
|
95
|
+
# NOTE: The original implementation accessed private buffers via
|
96
|
+
# `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when
|
97
|
+
# available; fall back to private fields if needed.
|
98
|
+
k_pool=getattr(
|
99
|
+
kvcache,
|
100
|
+
"k_buffer",
|
101
|
+
getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
|
102
|
+
),
|
103
|
+
v_pool=getattr(
|
104
|
+
kvcache,
|
105
|
+
"v_buffer",
|
106
|
+
getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
|
107
|
+
),
|
108
|
+
tp_group=tp_group,
|
109
|
+
)
|
110
|
+
|
111
|
+
self.load_stream = torch.cuda.Stream()
|
112
|
+
self.store_stream = torch.cuda.Stream()
|
113
|
+
|
114
|
+
self.layer_done_executor = LayerTransferCounter(
|
115
|
+
num_layers=(
|
116
|
+
model_config.num_hidden_layers if model_config is not None else 0
|
117
|
+
),
|
118
|
+
load_stream=self.load_stream,
|
119
|
+
lmc_connector=self.lmcache_connector,
|
120
|
+
)
|
121
|
+
kvcache.register_layer_transfer_counter(self.layer_done_executor)
|
122
|
+
|
123
|
+
self._in_flight_nodes: list[TreeNode] = []
|
124
|
+
self._node_lock = threading.Lock()
|
125
|
+
|
126
|
+
def reset(self): # type: ignore[override]
|
127
|
+
super().reset()
|
128
|
+
if hasattr(self, "_in_flight_nodes"):
|
129
|
+
with self._node_lock:
|
130
|
+
self._in_flight_nodes.clear()
|
131
|
+
|
132
|
+
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:  # type: ignore[override]
    """Match the cached prefix and, on a tail miss, try to extend it from LMCache.

    The base-class match supplies the device indices and the last matched
    node. If part of the (page-aligned) key is still unmatched, token slots
    are allocated for it (evicting first when the allocator reports too little
    free space), an LMCache load into those slots is started on
    ``self.load_stream``, and the tokens LMCache reports as retrieved are
    attached to the tree as a new child node.

    Args:
        key: Token ids to match. When ``page_size != 1`` it is truncated to a
            multiple of ``page_size`` before matching.
        **kwargs: Forwarded unchanged to the base ``match_prefix``.

    Returns:
        The base-class result when nothing beyond the matched prefix could be
        loaded (full hit, allocation failure, or no usable tokens retrieved);
        otherwise a ``MatchResult`` whose indices also contain the loaded
        slots and whose device/host last node is the newly created node.
    """
    if self.disable or not key:
        return super().match_prefix(key, **kwargs)

    if self.page_size != 1:
        # Only whole pages take part in matching; drop the partial tail.
        aligned_len = len(key) // self.page_size * self.page_size
        key = key[:aligned_len]

    base_res = super().match_prefix(key, **kwargs)
    value: torch.Tensor = base_res.device_indices
    last_node: TreeNode = base_res.last_device_node

    matched_len = value.numel()
    uncached_len = len(key) - matched_len
    if uncached_len == 0:
        # Full hit: nothing left to fetch.
        return base_res

    # Matched tokens that lie past the last chunk boundary; the load offset
    # below is rounded down to that boundary.
    chunk_size = self.lmcache_connector.chunk_size()
    prefix_pad = matched_len % chunk_size

    # NOTE(review): the nodes matched above are not lock-protected while
    # evicting here -- confirm eviction cannot reclaim them.
    if self.token_to_kv_pool_allocator.available_size() < uncached_len:
        self.evict(uncached_len)

    token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
    if token_slots is None:
        return base_res

    # -1 for positions that are already cached, followed by the new slots.
    slot_mapping = torch.cat(
        [
            torch.full((matched_len,), -1, dtype=torch.int64, device=self.device),
            token_slots.detach().clone().to(torch.int64).to(self.device),
        ]
    )

    with torch.cuda.stream(self.load_stream):
        num_retrieved = self.lmcache_connector.start_load_kv(
            LoadMetadata(
                token_ids=key,  # full page-aligned key
                slot_mapping=slot_mapping,
                offset=matched_len - prefix_pad,  # LMCache offset convention
            )
        )
    logger.debug("num_retrieved_tokens: %s", num_retrieved)

    # The retrieved count presumably starts at the chunk-aligned offset, so the
    # first `prefix_pad` tokens overlap the prefix that was already matched
    # (TODO confirm against the LMCache connector).
    fetched = num_retrieved - prefix_pad
    if fetched <= 0:
        # Nothing new was loaded. Returning here also keeps a zero or negative
        # `fetched` from being used as a slice bound, which would free/attach
        # the wrong slots or create a child node with an empty key.
        self.token_to_kv_pool_allocator.free(token_slots)
        return base_res

    # Give back the slots LMCache did not fill.
    self.token_to_kv_pool_allocator.free(token_slots[fetched:])

    # Attach the loaded tokens as a new child of the last matched node.
    start = matched_len
    end = start + fetched
    new_node = TreeNode()
    new_node.key = key[start:end]
    new_node.value = token_slots[:fetched]
    new_node.parent = last_node
    last_node.children[self.get_child_key_fn(new_node.key)] = new_node
    last_node = new_node

    value = torch.cat([value, token_slots[:fetched]])
    self.evictable_size_ += fetched

    self._record_store_event(new_node.parent)
    self._record_store_event(new_node)

    return MatchResult(
        device_indices=value,
        last_device_node=last_node,
        last_host_node=last_node,
    )
|
217
|
+
|
218
|
+
def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
    """Run the base completion handling, then hand the request's KV to LMCache.

    After the base class has processed the finished request, the request's
    token ids and their KV indices are passed to ``lmcache_connector.store_kv``
    on ``self.store_stream``. The last matched tree node is lock-referenced
    and remembered so that ``evict`` can release it later.
    """
    super().cache_finished_req(req)

    # The concatenated input+output ids without the final token.
    all_ids = req.origin_input_ids + req.output_ids
    token_ids = all_ids[:-1]
    num_tokens = len(token_ids)
    kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :num_tokens]

    match = self.match_prefix(token_ids)
    _, new_last_node, _, _ = match
    assert new_last_node is not None

    # Pin the node while the store is outstanding; evict() drops this
    # reference after synchronizing the store stream.
    # NOTE(review): evict() returns early when `self.disable` is set, while
    # this method has no such check -- confirm nodes cannot pile up here.
    self.inc_lock_ref(new_last_node)

    store_md = StoreMetadata(
        last_node=new_last_node,
        token_ids=token_ids,
        kv_indices=kv_indices,
        offset=0,
    )
    with torch.cuda.stream(self.store_stream):
        self.lmcache_connector.store_kv(store_md)

    with self._node_lock:
        self._in_flight_nodes.append(new_last_node)
|
242
|
+
|
243
|
+
def evict(self, num_tokens: int) -> None:  # type: ignore[override]
    """Release nodes pinned for pending stores, then run the base eviction.

    Does nothing when the cache is disabled. Otherwise it first waits on
    ``self.store_stream`` and drops the lock reference of every tracked node.

    Args:
        num_tokens: Passed through unchanged to the base ``evict``.
    """
    if self.disable:
        return

    # Wait for work queued on the store stream before unpinning its nodes.
    self.store_stream.synchronize()

    with self._node_lock:
        in_flight = self._in_flight_nodes
        for pinned_node in in_flight:
            self.dec_lock_ref(pinned_node)
        in_flight.clear()

    super().evict(num_tokens)
|
255
|
+
|
256
|
+
def pretty_print(self):  # type: ignore[override]
    """Run the base ``pretty_print``, then log the size counters at debug level.

    The extra log line is best-effort: if either counter attribute is missing
    the method logs that fact and returns instead of raising.
    """
    super().pretty_print()

    try:
        evictable, protected = self.evictable_size_, self.protected_size_
    except AttributeError:
        # Keep this diagnostic helper non-fatal, but say why the line is
        # skipped rather than swallowing every exception silently.
        logger.debug("evictable/protected size counters are unavailable")
        return

    logger.debug("evictable=%d protected=%d", evictable, protected)
|
264
|
+
|
265
|
+
|
266
|
+
if __name__ == "__main__":
    # Manual smoke test: build a cache without pools, insert two overlapping
    # token sequences and print the resulting tree.
    # NOTE(review): the constructor calls
    # `self.token_to_kv_pool_allocator.get_kvcache()`; with the allocator passed
    # as None this presumably fails before any insert -- confirm.
    cache = LMCRadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=1,
        disable=False,
        enable_kv_cache_events=False,
        model_config=None,
        tp_size=1,
        rank=0,
        tp_group=None,
    )

    samples = (
        ([1, 2, 3], [10, 11, 12]),
        ([1, 2, 3, 4], [10, 11, 12, 13]),
    )
    for token_ids, slot_ids in samples:
        cache.insert(token_ids, torch.tensor(slot_ids, dtype=torch.int64))

    cache.pretty_print()
|