sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -3,22 +3,26 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional
 
 import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu
 
 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
+        transfer_kv_all_layer_direct_lf_pf,
         transfer_kv_all_layer_lf_pf,
         transfer_kv_all_layer_mla,
         transfer_kv_all_layer_mla_lf_pf,
         transfer_kv_direct,
         transfer_kv_per_layer,
+        transfer_kv_per_layer_direct_pf_lf,
         transfer_kv_per_layer_mla,
         transfer_kv_per_layer_mla_pf_lf,
         transfer_kv_per_layer_pf_lf,
@@ -76,6 +80,7 @@ class HostKVCache(abc.ABC):
         self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
+        self.page_num = self.size // self.page_size
         self.start_layer = device_pool.start_layer
         self.end_layer = device_pool.end_layer
 
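The new page_num field just counts whole pages once size has been aligned down, so it is always an exact page count. With illustrative numbers:

    size, page_size = 1000, 64
    size -= size % page_size      # 960, a multiple of the page size
    page_num = size // page_size  # 15 whole pages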
@@ -168,7 +173,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)
 
     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
@@ -315,6 +320,15 @@ class MHATokenToKVPoolHost(HostKVCache):
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
         elif self.layout == "page_first":
             dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        elif self.layout == "page_first_direct":
+            dims = (
+                2,
+                self.page_num,
+                self.layer_num,
+                self.page_size,
+                self.head_num,
+                self.head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
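Reading the added dims tuple, a token slot in the page_first_direct layout is addressed by page, then layer, then in-page slot (dim 0 picks K or V, and the trailing head_num/head_dim dims hold the per-token payload). A hypothetical helper, not part of the package, mapping a token index into those dimensions:

    def locate(token_idx: int, layer: int, page_size: int):
        # Indexes dims 1..3 of a (2, page_num, layer_num, page_size, ...) buffer.
        page, slot = divmod(token_idx, page_size)
        return page, layer, slot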
@@ -368,19 +382,31 @@ class MHATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            ...
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                    dst_layers=[
+                        device_pool.k_buffer[layer_id],
+                        device_pool.v_buffer[layer_id],
+                    ],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_per_layer_direct_pf_lf(
+                    src_ptrs=[self.k_buffer, self.v_buffer],
+                    dst_ptrs=[
+                        device_pool.k_buffer[layer_id],
+                        device_pool.v_buffer[layer_id],
+                    ],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    layer_id=layer_id,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
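transfer_kv_direct and the *_direct_pf_lf/_lf_pf variants are sgl_kernel functions, so their bodies are not in this diff. Purely as a mental model of a page-granular copy, and assuming the token indices arrive page-aligned, a sketch might look like:

    import torch

    def copy_pages_sketch(src, dst, src_indices, dst_indices, page_size):
        # Treat page-aligned token indices as rows of whole pages and copy
        # page by page along the token dimension.
        for s, d in zip(src_indices.view(-1, page_size),
                        dst_indices.view(-1, page_size)):
            dst[d] = src[s]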
@@ -414,16 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            ...
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                    dst_layers=self.k_data_refs + self.v_data_refs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_all_layer_direct_lf_pf(
+                    src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
+                    dst_ptrs=[self.k_buffer, self.v_buffer],
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
@@ -578,6 +612,14 @@ class MLATokenToKVPoolHost(HostKVCache):
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
             )
+        elif self.layout == "page_first_direct":
+            dims = (
+                self.page_num,
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = (
@@ -617,16 +659,25 @@ class MLATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            ...
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=[self.kv_buffer[layer_id]],
+                    dst_layers=[device_pool.kv_buffer[layer_id]],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_per_layer_direct_pf_lf(
+                    src_ptrs=[self.kv_buffer],
+                    dst_ptrs=[device_pool.kv_buffer[layer_id]],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    layer_id=layer_id,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
 
     def backup_from_device_all_layer(
         self, device_pool, host_indices, device_indices, io_backend
@@ -654,16 +705,24 @@ class MLATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            ...
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=device_pool.kv_buffer,
+                    dst_layers=self.data_refs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_all_layer_direct_lf_pf(
+                    src_ptrs=device_pool.kv_buffer,
+                    dst_ptrs=[self.kv_buffer],
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()
 
         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
@@ -0,0 +1,164 @@
+import logging
+import os
+import threading
+from abc import ABC, abstractmethod
+from typing import List
+
+import torch
+
+
+class Hf3fsClient(ABC):
+    """Abstract interface for HF3FS clients."""
+
+    @abstractmethod
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize the HF3FS client.
+
+        Args:
+            path: File path for storage
+            size: Total size of storage file
+            bytes_per_page: Bytes per page
+            entries: Number of entries for batch operations
+        """
+        pass
+
+    @abstractmethod
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from storage."""
+        pass
+
+    @abstractmethod
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to storage."""
+        pass
+
+    @abstractmethod
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    @abstractmethod
+    def get_size(self) -> int:
+        """Get total storage size."""
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the client and cleanup resources."""
+        pass
+
+    @abstractmethod
+    def flush(self) -> None:
+        """Flush data to disk."""
+        pass
+
+
+logger = logging.getLogger(__name__)
+
+
+class Hf3fsMockClient(Hf3fsClient):
+    """Mock implementation of Hf3fsClient for CI testing purposes."""
+
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize mock HF3FS client."""
+        self.path = path
+        self.size = size
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+
+        # Create directory if it doesn't exist
+        os.makedirs(os.path.dirname(self.path), exist_ok=True)
+
+        # Create and initialize the file
+        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
+        os.ftruncate(self.file, size)
+
+        logger.info(
+            f"Hf3fsMockClient initialized: path={path}, size={size}, "
+            f"bytes_per_page={bytes_per_page}, entries={entries}"
+        )
+
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_read = os.read(self.file, size)
+
+                if len(bytes_read) == size:
+                    # Convert bytes to tensor and copy to target
+                    bytes_tensor = torch.frombuffer(bytes_read, dtype=torch.uint8)
+                    typed_tensor = bytes_tensor.view(tensor.dtype).view(tensor.shape)
+                    tensor.copy_(typed_tensor)
+                    results.append(size)
+                else:
+                    logger.warning(
+                        f"Short read: expected {size}, got {len(bytes_read)}"
+                    )
+                    results.append(len(bytes_read))
+
+            except Exception as e:
+                logger.error(f"Error reading from offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                # Convert tensor to bytes and write directly to file
+                tensor_bytes = tensor.contiguous().view(torch.uint8).flatten()
+                data = tensor_bytes.numpy().tobytes()
+
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_written = os.write(self.file, data)
+
+                if bytes_written == size:
+                    results.append(size)
+                else:
+                    logger.warning(f"Short write: expected {size}, got {bytes_written}")
+                    results.append(bytes_written)
+
+            except Exception as e:
+                logger.error(f"Error writing to offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    def get_size(self) -> int:
+        """Get total storage size."""
+        return self.size
+
+    def close(self) -> None:
+        """Close the mock client and cleanup resources."""
+        try:
+            if hasattr(self, "file") and self.file >= 0:
+                os.close(self.file)
+                self.file = -1  # Mark as closed
+                logger.info(f"MockHf3fsClient closed: {self.path}")
+        except Exception as e:
+            logger.error(f"Error closing MockHf3fsClient: {e}")
+
+    def flush(self) -> None:
+        """Flush data to disk."""
+        try:
+            os.fsync(self.file)
+        except Exception as e:
+            logger.error(f"Error flushing MockHf3fsClient: {e}")
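A quick round trip through the new mock client (temp path and sizes are illustrative) checks the contract that each batch call returns the byte count per tensor:

    import os, tempfile
    import torch
    from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient

    path = os.path.join(tempfile.mkdtemp(), "mock.bin")
    client = Hf3fsMockClient(path, size=1 << 20, bytes_per_page=4096, entries=8)

    src = [torch.arange(1024, dtype=torch.float32)]  # 1024 * 4 = 4096 bytes
    dst = [torch.empty(1024, dtype=torch.float32)]
    assert client.batch_write([0], src) == [4096]
    assert client.batch_read([0], dst) == [4096]
    assert torch.equal(src[0], dst[0])
    client.close()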
@@ -9,6 +9,8 @@ from typing import List
 import torch
 from torch.utils.cpp_extension import load
 
+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+
 root = Path(__file__).parent.resolve()
 hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
 
@@ -51,7 +53,9 @@ def wsynchronized():
     return _decorator
 
 
-class Hf3fsClient:
+class Hf3fsUsrBioClient(Hf3fsClient):
+    """HF3FS client implementation using usrbio."""
+
     def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
         if not HF3FS_AVAILABLE:
             raise ImportError(
@@ -5,6 +5,7 @@ import logging
 import os
 import signal
 import threading
+import time
 from abc import ABC, abstractmethod
 from functools import wraps
 from typing import Any, List, Optional, Tuple
@@ -12,7 +13,8 @@ from typing import Any, List, Optional, Tuple
 import torch
 
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
-from sglang.srt.mem_cache.storage.hf3fs.
+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+from sglang.srt.metrics.collector import StorageMetrics
 
 logger = logging.getLogger(__name__)
 
@@ -112,6 +114,33 @@ def synchronized():
     return _decorator
 
 
+def create_hf3fs_client(
+    path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
+) -> Hf3fsClient:
+    """Factory function to create appropriate HF3FS client.
+
+    Args:
+        path: File path for storage
+        size: Total size of storage file
+        bytes_per_page: Bytes per page
+        entries: Number of entries for batch operations
+        use_mock: Whether to use mock client instead of real usrbio client
+
+    Returns:
+    """
+    if use_mock:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient
+
+        logger.info(f"[Rank Using Hf3fsMockClient for testing")
+        return Hf3fsMockClient(path, size, bytes_per_page, entries)
+    else:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
+            Hf3fsUsrBioClient,
+        )
+
+        return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
+
+
 class HiCacheHF3FS(HiCacheStorage):
     """HiCache backend that stores KV cache pages in HF3FS files."""
 
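A minimal factory call, with illustrative path and sizes; use_mock=True sidesteps the usrbio dependency, which is the point for CI:

    client = create_hf3fs_client(
        "/tmp/hicache/pages.bin",
        size=1 << 30,            # 1 GiB backing file (illustrative)
        bytes_per_page=1 << 21,  # 2 MiB pages (illustrative)
        entries=8,
        use_mock=True,           # pick Hf3fsMockClient instead of usrbio
    )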
@@ -129,12 +158,14 @@ class HiCacheHF3FS(HiCacheStorage):
         metadata_client: Hf3fsMetadataInterface,
         is_mla_model: bool = False,
         is_page_first_layout: bool = False,
+        use_mock_client: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
+        self.gb_per_page = bytes_per_page / (1 << 30)
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
@@ -156,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage):
 
         self.ac = AtomicCounter(self.numjobs)
         self.clients = [
-            Hf3fsClient(
-                self.file_path,
+            create_hf3fs_client(
+                self.file_path,
+                self.file_size,
+                self.bytes_per_page,
+                self.entries,
+                use_mock_client,
             )
             for _ in range(numjobs)
         ]
@@ -174,6 +209,11 @@ class HiCacheHF3FS(HiCacheStorage):
         signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
         signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
 
+        self.prefetch_pgs = []
+        self.backup_pgs = []
+        self.prefetch_bandwidth = []
+        self.backup_bandwidth = []
+
     @staticmethod
     def from_env_config(
         bytes_per_page: int,
@@ -194,14 +234,24 @@ class HiCacheHF3FS(HiCacheStorage):
             Hf3fsLocalMetadataClient,
         )
 
+        use_mock_client = False
         if storage_config is not None:
             rank, is_mla_model, is_page_first_layout = (
                 storage_config.tp_rank,
                 storage_config.is_mla_model,
                 storage_config.is_page_first_layout,
             )
+
+            if storage_config.extra_config is not None:
+                use_mock_client = storage_config.extra_config.get(
+                    "use_mock_hf3fs_client", False
+                )
         else:
-            rank, is_mla_model, is_page_first_layout =
+            rank, is_mla_model, is_page_first_layout = (
+                0,
+                False,
+                False,
+            )
 
         mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
 
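The mock client can therefore be switched on from deployment config rather than code. Assuming extra_config is a plain dict on HiCacheStorageConfig, the opt-in mirrors the read pattern above:

    # Hypothetical config snippet: only the key name comes from the diff.
    extra_config = {"use_mock_hf3fs_client": True}
    use_mock_client = (extra_config or {}).get("use_mock_hf3fs_client", False)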
@@ -220,6 +270,7 @@ class HiCacheHF3FS(HiCacheStorage):
             dtype=dtype,
             metadata_client=Hf3fsLocalMetadataClient(),
             is_page_first_layout=is_page_first_layout,
+            use_mock_client=use_mock_client,
         )
 
         try:
@@ -269,6 +320,7 @@ class HiCacheHF3FS(HiCacheStorage):
             metadata_client=metadata_client,
             is_mla_model=is_mla_model,
             is_page_first_layout=is_page_first_layout,
+            use_mock_client=use_mock_client,
         )
 
     def get(
@@ -308,6 +360,8 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(len(batch_indices))
         ]
 
+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_read,
@@ -318,6 +372,13 @@ class HiCacheHF3FS(HiCacheStorage):
         ]
         read_results = [result for future in futures for result in future.result()]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.prefetch_pgs.append(ionum)
+        self.prefetch_bandwidth.append(
+            ionum / (end_time - start_time) * self.gb_per_page
+        )
+
         results = [None] * len(keys)
         for batch_index, file_result, read_result in zip(
             batch_indices, file_results, read_results
@@ -345,6 +406,7 @@ class HiCacheHF3FS(HiCacheStorage):
             [target_sizes] if target_sizes is not None else None,
         )
 
+    @synchronized()
     def batch_set(
         self,
         keys: List[str],
@@ -374,6 +436,8 @@ class HiCacheHF3FS(HiCacheStorage):
             assert value.is_contiguous()
             file_values.append(value)
 
+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_write,
@@ -388,6 +452,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.backup_pgs.append(ionum)
+        self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
+
         written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
@@ -439,3 +508,16 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             logger.error(f"close HiCacheHF3FS: {e}")
         logger.info("close HiCacheHF3FS")
+
+    @synchronized()
+    def get_stats(self):
+        storage_metrics = StorageMetrics()
+        storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
+        storage_metrics.backup_pgs.extend(self.backup_pgs)
+        storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth)
+        storage_metrics.backup_bandwidth.extend(self.backup_bandwidth)
+        self.prefetch_pgs.clear()
+        self.backup_pgs.clear()
+        self.prefetch_bandwidth.clear()
+        self.backup_bandwidth.clear()
+        return storage_metrics
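The bandwidth samples recorded around the batch_read/batch_write calls above are pages per second scaled by gb_per_page; a worked example with illustrative numbers:

    bytes_per_page = 2 << 20                   # 2 MiB per page
    gb_per_page = bytes_per_page / (1 << 30)   # = 1/512 GiB per page
    ionum, elapsed = 512, 0.25                 # pages moved in 0.25 s
    bandwidth = ionum / elapsed * gb_per_page  # = 4.0 GiB/s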