sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py

@@ -5,6 +5,7 @@ import logging
 import os
 import signal
 import threading
+import time
 from abc import ABC, abstractmethod
 from functools import wraps
 from typing import Any, List, Optional, Tuple
@@ -12,7 +13,8 @@ from typing import Any, List, Optional, Tuple
 import torch
 
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
-from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+from sglang.srt.metrics.collector import StorageMetrics
 
 logger = logging.getLogger(__name__)
 
@@ -112,7 +114,36 @@ def synchronized():
     return _decorator
 
 
+def create_hf3fs_client(
+    path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
+) -> Hf3fsClient:
+    """Factory function to create the appropriate HF3FS client.
+
+    Args:
+        path: File path for storage
+        size: Total size of storage file
+        bytes_per_page: Bytes per page
+        entries: Number of entries for batch operations
+        use_mock: Whether to use the mock client instead of the real usrbio client
+
+    Returns: An `Hf3fsClient` instance.
+    """
+    if use_mock:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient
+
+        logger.info("Using Hf3fsMockClient for testing")
+        return Hf3fsMockClient(path, size, bytes_per_page, entries)
+    else:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
+            Hf3fsUsrBioClient,
+        )
+
+        return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
+
+
 class HiCacheHF3FS(HiCacheStorage):
+    """HiCache backend that stores KV cache pages in HF3FS files."""
+
     default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
 
     def __init__(
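A minimal usage sketch of the new `create_hf3fs_client` factory, assuming only the signature shown in the hunk above; the path and sizes are illustrative placeholders, not values from the diff.

```python
# Hypothetical usage of create_hf3fs_client (signature taken from the hunk above).
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import create_hf3fs_client

client = create_hf3fs_client(
    path="/tmp/hicache.0.bin",  # scratch file path, illustrative only
    size=1 << 30,               # 1 GiB backing file
    bytes_per_page=64 << 10,    # 64 KiB pages
    entries=8,                  # batch-IO entries per client
    use_mock=True,              # True -> Hf3fsMockClient; False -> Hf3fsUsrBioClient
)
```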
@@ -125,18 +156,27 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
+        is_page_first_layout: bool = False,
+        use_mock_client: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
+        self.gb_per_page = bytes_per_page / (1 << 30)
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
+        self.is_page_first_layout = is_page_first_layout
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -147,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage):
 
         self.ac = AtomicCounter(self.numjobs)
         self.clients = [
-            Hf3fsClient(
-                self.file_path,
+            create_hf3fs_client(
+                self.file_path,
+                self.file_size,
+                self.bytes_per_page,
+                self.entries,
+                use_mock_client,
             )
             for _ in range(numjobs)
         ]
@@ -165,21 +209,57 @@ class HiCacheHF3FS(HiCacheStorage):
         signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
         signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
 
+        self.prefetch_pgs = []
+        self.backup_pgs = []
+        self.prefetch_bandwidth = []
+        self.backup_bandwidth = []
+
     @staticmethod
     def from_env_config(
         bytes_per_page: int,
         dtype: torch.dtype,
         storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
+        """Create a HiCacheHF3FS instance from environment configuration.
+
+        Environment:
+        - Uses the env var named by `HiCacheHF3FS.default_env_var` to locate a JSON config.
+        - Falls back to a local single-machine config when the env var is not set.
+
+        Raises:
+            ValueError: If an MLA model is used without a global metadata server, or required keys are missing.
+        """
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
-        rank = storage_config.tp_rank if storage_config is not None else 0
+        use_mock_client = False
+        if storage_config is not None:
+            rank, is_mla_model, is_page_first_layout = (
+                storage_config.tp_rank,
+                storage_config.is_mla_model,
+                storage_config.is_page_first_layout,
+            )
+
+            if storage_config.extra_config is not None:
+                use_mock_client = storage_config.extra_config.get(
+                    "use_mock_hf3fs_client", False
+                )
+        else:
+            rank, is_mla_model, is_page_first_layout = (
+                0,
+                False,
+                False,
+            )
+
+        mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
 
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             return HiCacheHF3FS(
                 rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
@@ -189,6 +269,8 @@ class HiCacheHF3FS(HiCacheStorage):
                 entries=8,
                 dtype=dtype,
                 metadata_client=Hf3fsLocalMetadataClient(),
+                is_page_first_layout=is_page_first_layout,
+                use_mock_client=use_mock_client,
             )
 
         try:
@@ -209,26 +291,36 @@ class HiCacheHF3FS(HiCacheStorage):
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
-        if "metadata_server_url" in config and config["metadata_server_url"]:
+        if config.get("metadata_server_url"):
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
         else:
+            # Enable MLA optimization only when using the global metadata client
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             # Use local metadata client for single-machine deployment
             metadata_client = Hf3fsLocalMetadataClient()
 
+        rank_for_path = 0 if is_mla_model else rank
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
            file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
+            is_page_first_layout=is_page_first_layout,
+            use_mock_client=use_mock_client,
        )
 
     def get(
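The config branch above reads a JSON file pointed to by `SGLANG_HICACHE_HF3FS_CONFIG_PATH`. A sketch of a plausible config, inferred only from the keys the code reads (`file_path_prefix`, `file_size`, `numjobs`, `entries`, and the optional `metadata_server_url`); every value below is a made-up placeholder.

```python
# Write an illustrative HF3FS config and point the env var at it.
import json
import os
import tempfile

config = {
    "file_path_prefix": "/data/hicache",       # per-rank files become <prefix>.<rank>.bin
    "file_size": 1 << 40,                       # bytes of HF3FS backing file (placeholder)
    "numjobs": 16,                              # number of parallel IO clients
    "entries": 8,                               # batch-IO entries per client
    "metadata_server_url": "http://meta:8000",  # omit to use the local metadata client
}

path = os.path.join(tempfile.gettempdir(), "hf3fs_config.json")
with open(path, "w") as f:
    json.dump(config, f)
os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = path
```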
@@ -268,6 +360,8 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(len(batch_indices))
         ]
 
+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_read,
@@ -278,6 +372,13 @@ class HiCacheHF3FS(HiCacheStorage):
         ]
         read_results = [result for future in futures for result in future.result()]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.prefetch_pgs.append(ionum)
+        self.prefetch_bandwidth.append(
+            ionum / (end_time - start_time) * self.gb_per_page
+        )
+
         results = [None] * len(keys)
         for batch_index, file_result, read_result in zip(
             batch_indices, file_results, read_results
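The instrumentation above turns each batch read into a bandwidth sample: with `gb_per_page = bytes_per_page / (1 << 30)`, the appended value `ionum / (end_time - start_time) * gb_per_page` is the batch throughput in GB/s. A worked example with made-up sizes:

```python
# Worked example of the bandwidth bookkeeping above (illustrative numbers only).
bytes_per_page = 1 << 20                    # 1 MiB pages
gb_per_page = bytes_per_page / (1 << 30)    # = 1/1024 GiB per page

ionum = 512                                 # pages moved by this batch_get call
elapsed = 0.1                               # seconds between the two perf_counter() calls

bandwidth = ionum / elapsed * gb_per_page
print(f"{bandwidth:.2f} GB/s")              # 512 MiB in 0.1 s -> 5.00 GB/s
```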
@@ -305,6 +406,7 @@ class HiCacheHF3FS(HiCacheStorage):
             [target_sizes] if target_sizes is not None else None,
         )
 
+    @synchronized()
     def batch_set(
         self,
         keys: List[str],
@@ -312,6 +414,10 @@ class HiCacheHF3FS(HiCacheStorage):
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In the MLA backend, only one rank needs to back up the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -330,6 +436,8 @@ class HiCacheHF3FS(HiCacheStorage):
             assert value.is_contiguous()
             file_values.append(value)
 
+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_write,
@@ -344,6 +452,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.backup_pgs.append(ionum)
+        self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
+
         written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
@@ -363,18 +476,29 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-    @synchronized()
-    def clear(self) -> None:
-        self.metadata_client.clear(self.rank)
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False
 
     def close(self) -> None:
         try:
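Note that the new `batch_exists` returns a count rather than a boolean: the index of the first missing key, i.e. the length of the leading run of keys that already exist. A self-contained sketch of that contract, with the metadata lookup stubbed out:

```python
# Sketch of the batch_exists contract above; the set lookup stands in for
# metadata_client.exists(), which returns one boolean per key.
from typing import List


def batch_exists(keys: List[str], stored: set) -> int:
    results = [k in stored for k in keys]  # stand-in for the metadata query
    for i, ok in enumerate(results):
        if not ok:
            return i  # first miss -> length of the existing prefix
    return len(keys)


assert batch_exists(["a", "b", "c"], {"a", "b"}) == 2  # "c" is the first miss
assert batch_exists(["a", "b"], {"a", "b"}) == 2       # all present
assert batch_exists(["x"], set()) == 0                 # nothing cached
```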
@@ -384,3 +508,16 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             logger.error(f"close HiCacheHF3FS: {e}")
         logger.info("close HiCacheHF3FS")
+
+    @synchronized()
+    def get_stats(self):
+        storage_metrics = StorageMetrics()
+        storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
+        storage_metrics.backup_pgs.extend(self.backup_pgs)
+        storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth)
+        storage_metrics.backup_bandwidth.extend(self.backup_bandwidth)
+        self.prefetch_pgs.clear()
+        self.backup_pgs.clear()
+        self.prefetch_bandwidth.clear()
+        self.backup_bandwidth.clear()
+        return storage_metrics
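`get_stats` drains the accumulated page counts and bandwidth samples into a `StorageMetrics` object and resets the buffers, so each call returns only the samples collected since the previous call. A hypothetical polling loop; the interval and the `report` consumer are assumptions, not part of the diff:

```python
# Hypothetical consumer of the new get_stats() API.
import time


def poll_storage_metrics(backend, report, interval_s: float = 10.0):
    while True:
        time.sleep(interval_s)
        metrics = backend.get_stats()  # drains and resets the internal buffers
        if metrics.prefetch_pgs or metrics.backup_pgs:
            report(
                prefetch_pages=sum(metrics.prefetch_pgs),
                backup_pages=sum(metrics.backup_pgs),
            )
```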
sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py (new file)

@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError as e:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache`"
+    ) from e
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+
+class LayerTransferCounter:
+    """Minimal adapter that lets the memory pool notify LMCache per-layer.
+
+    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
+    translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
+    within the provided CUDA stream.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        load_stream: torch.cuda.Stream,
+        lmc_connector: LMCacheLayerwiseConnector,
+        printable: bool = False,
+    ):
+        self.num_layers = num_layers
+        self.load_stream = load_stream
+        self.lmc_connector = lmc_connector
+
+    def wait_until(self, layer_id: int):
+        # Ensure ordering of the async loads wrt compute stream(s).
+        self.load_stream.synchronize()
+        with self.load_stream:
+            self.lmc_connector.load_kv_layerwise(layer_id)
+
+
+class LMCRadixCache(RadixCache):
+    """RadixCache + LMCache IO.
+
+    This subclass adds:
+    - LMCache connector setup (device/host buffers, TP rank/size)
+    - Two CUDA streams for async load/store
+    - Layer-wise transfer executor wiring to the KV cache
+    - Overridden `match_prefix` to fetch missing prefix chunks from LMCache
+    - Extended cache_finalization paths to store back into LMCache
+    - Eviction barrier that respects any in-flight host->device stores
+    """
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        page_size: int,
+        disable: bool = False,
+        enable_kv_cache_events: bool = False,
+        model_config: Optional["ModelConfig"] = None,
+        tp_size: int = 1,
+        rank: int = 0,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        super().__init__(
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
+            page_size=page_size,
+            disable=disable,
+            enable_kv_cache_events=enable_kv_cache_events,
+        )
+
+        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
+        self.lmcache_connector = LMCacheLayerwiseConnector(
+            sgl_config=model_config,
+            tp_size=tp_size,
+            rank=rank,
+            # NOTE: The original implementation accessed private buffers via
+            # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when
+            # available; fall back to private fields if needed.
+            k_pool=getattr(
+                kvcache,
+                "k_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
+            ),
+            v_pool=getattr(
+                kvcache,
+                "v_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
+            ),
+            tp_group=tp_group,
+        )
+
+        self.load_stream = torch.cuda.Stream()
+        self.store_stream = torch.cuda.Stream()
+
+        self.layer_done_executor = LayerTransferCounter(
+            num_layers=(
+                model_config.num_hidden_layers if model_config is not None else 0
+            ),
+            load_stream=self.load_stream,
+            lmc_connector=self.lmcache_connector,
+        )
+        kvcache.register_layer_transfer_counter(self.layer_done_executor)
+
+        self._in_flight_nodes: list[TreeNode] = []
+        self._node_lock = threading.Lock()
+
+    def reset(self):  # type: ignore[override]
+        super().reset()
+        if hasattr(self, "_in_flight_nodes"):
+            with self._node_lock:
+                self._in_flight_nodes.clear()
+
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:  # type: ignore[override]
+        """Match cached prefix; if there's a tail miss, prefetch from LMCache.
+
+        Reuses the base matching logic to obtain (value, last_node). If there
+        remains a *page-aligned* uncached suffix and there is room (or after
+        eviction), we allocate token slots and trigger an async LMCache load
+        into those slots, then materialize a new child node for the retrieved
+        chunk.
+        """
+        if self.disable or not key:
+            return super().match_prefix(key, **kwargs)
+
+        if self.page_size != 1:
+            aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:aligned_len]
+
+        base_res = super().match_prefix(key, **kwargs)
+        value: torch.Tensor = base_res.device_indices
+        last_node: TreeNode = base_res.last_device_node
+
+        if value.numel() == len(key):
+            return base_res
+
+        uncached_len = len(key) - value.numel()
+        if uncached_len == 0:
+            return base_res
+
+        chunk_size = self.lmcache_connector.chunk_size()
+        prefix_pad = value.numel() % chunk_size
+
+        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
+            self.evict(uncached_len)
+
+        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
+        if token_slots is None:
+            return base_res
+
+        slot_mapping = torch.cat(
+            [
+                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
+                token_slots.detach().clone().to(torch.int64).to(self.device),
+            ]
+        )
+
+        with torch.cuda.stream(self.load_stream):
+            num_retrieved = self.lmcache_connector.start_load_kv(
+                LoadMetadata(
+                    token_ids=key,  # full page-aligned key
+                    slot_mapping=slot_mapping,
+                    offset=value.numel() - prefix_pad,  # LMCache offset convention
+                )
+            )
+        logger.debug("num_retrieved_tokens: %s", num_retrieved)
+
+        if num_retrieved > 0:
+            self.token_to_kv_pool_allocator.free(
+                token_slots[(num_retrieved - prefix_pad) :]
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(token_slots)
+
+        if num_retrieved > 0:
+            fetched = num_retrieved - prefix_pad
+            new_node = TreeNode()
+            start = value.numel()
+            end = start + fetched
+            new_node.key = key[start:end]
+            new_node.value = token_slots[:fetched]
+            new_node.parent = last_node
+            last_node.children[self.get_child_key_fn(new_node.key)] = new_node
+            last_node = new_node
+
+            value = torch.cat([value, token_slots[:fetched]])
+            self.evictable_size_ += fetched
+
+            self._record_store_event(new_node.parent)
+            self._record_store_event(new_node)
+
+            return MatchResult(
+                device_indices=value,
+                last_device_node=last_node,
+                last_host_node=last_node,
+            )
+
+        return base_res
+
+    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
+        """On request completion, insert device KV into radix and store to LMCache."""
+
+        super().cache_finished_req(req)
+
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        _, new_last_node, _, _ = self.match_prefix(token_ids)
+        assert new_last_node is not None
+
+        self.inc_lock_ref(new_last_node)
+        store_md = StoreMetadata(
+            last_node=new_last_node,
+            token_ids=token_ids,
+            kv_indices=kv_indices,
+            offset=0,
+        )
+        with torch.cuda.stream(self.store_stream):
+            self.lmcache_connector.store_kv(store_md)
+        with self._node_lock:
+            self._in_flight_nodes.append(new_last_node)
+
+    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
+        """Before base eviction, wait for any outstanding stores and release locks."""
+        if self.disable:
+            return
+
+        self.store_stream.synchronize()
+        with self._node_lock:
+            for node in self._in_flight_nodes:
+                self.dec_lock_ref(node)
+            self._in_flight_nodes.clear()
+
+        super().evict(num_tokens)
+
+    def pretty_print(self):  # type: ignore[override]
+        super().pretty_print()
+        try:
+            logger.debug(
+                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
+            )
+        except Exception:  # pragma: no cover
+            pass
+
+
+if __name__ == "__main__":
+    cache = LMCRadixCache(
+        req_to_token_pool=None,
+        token_to_kv_pool_allocator=None,
+        page_size=1,
+        disable=False,
+        enable_kv_cache_events=False,
+        model_config=None,
+        tp_size=1,
+        rank=0,
+        tp_group=None,
+    )
+    cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
+    cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
+    cache.pretty_print()
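`match_prefix` above first trims the lookup key down to a page boundary, then rewinds the LMCache load offset to a chunk boundary via `prefix_pad`, re-fetching the partial chunk so the load starts chunk-aligned. The arithmetic in isolation, with illustrative sizes:

```python
# The alignment arithmetic used by LMCRadixCache.match_prefix, with made-up sizes.
page_size = 16
chunk_size = 256                               # lmcache_connector.chunk_size() in the code

key_len = 1000
aligned_len = key_len // page_size * page_size  # 992: drop the partial trailing page
cached = 300                                    # value.numel(): tokens already in the radix tree
uncached_len = aligned_len - cached             # 692 tokens to fetch from LMCache

prefix_pad = cached % chunk_size                # 44: overlap back to the last chunk edge
offset = cached - prefix_pad                    # 256: where the LMCache load starts
assert offset % chunk_size == 0
```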