sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py

```diff
@@ -3,16 +3,17 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional

 import psutil
 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu

 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)

     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
```
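With this change, `HostKVCache.alloc` can return `None` when the host pool cannot satisfy the request, so callers must branch on the result instead of assuming a tensor comes back. A minimal caller-side sketch (the `host_pool` object and sizes are illustrative, not from the diff):

```python
# Hedged sketch: `host_pool` stands in for any HostKVCache subclass.
# alloc() may now return None instead of always returning slot indices.
def try_host_backup(host_pool, num_tokens: int):
    need_size = num_tokens - (num_tokens % host_pool.page_size)  # page-align down
    if need_size == 0:
        return None
    slots = host_pool.alloc(need_size)
    if slots is None:
        return None  # host pool exhausted: skip the backup instead of crashing
    return slots
```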
```diff
@@ -464,10 +465,11 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
@@ -488,8 +490,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{
-            key_list.append(f"{key_}_{
+            key_list.append(f"{key_}_{local_rank}_k")
+            key_list.append(f"{key_}_{local_rank}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
```
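Two things change in `get_buffer_meta`: `indices` is converted once with `.tolist()` so the per-page loop does plain Python indexing rather than repeated tensor reads, and every page key is suffixed with the caller's `local_rank`, so tensor-parallel ranks that share a page hash still write distinct storage keys. The key shape in isolation (values made up):

```python
# Illustrative only: the rank-scoped key naming implied by the hunk above.
page_hash, local_rank = "3fa9c2", 1     # made-up hash and rank
k_key = f"{page_hash}_{local_rank}_k"   # "3fa9c2_1_k"
v_key = f"{page_hash}_{local_rank}_v"   # "3fa9c2_1_v"
# Rank 0 would emit "3fa9c2_0_k"/"3fa9c2_0_v", so ranks never collide.
```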
```diff
@@ -500,20 +502,23 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         key_list = []
         buf_list = []

-        for
+        for i in range(len(keys)):
+            key = keys[i]
             key_list.append(f"{key}-k")
-            buf_list.append(self.k_buffer[i : i + self.page_size])
             key_list.append(f"{key}-v")
-
+            if indices is not None:
+                index = indices[i * self.page_size]
+                buf_list.append(self.k_buffer[index : index + self.page_size])
+                buf_list.append(self.v_buffer[index : index + self.page_size])

-        return key_list, buf_list
+        return key_list, buf_list, 2


 class MLATokenToKVPoolHost(HostKVCache):
@@ -703,10 +708,11 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
@@ -727,13 +733,15 @@ class MLATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         buf_list = []

-
-
+        if indices is not None:
+            for i in range(len(keys)):
+                index = indices[i * self.page_size]
+                buf_list.append(self.kv_buffer[index : index + self.page_size])

-        return keys, buf_list
+        return keys, buf_list, 1
```
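Both `get_buffer_with_hash` variants now return a third element, the number of buffers emitted per page: 2 for the MHA pool (separate K and V buffers) and 1 for the MLA pool (one combined KV buffer). A caller can use that factor to iterate pages without knowing which pool it holds; a hedged sketch where `pool`, `page_hashes`, and `indices` are placeholders:

```python
# Sketch of consuming the new 3-tuple return; works for either pool type.
keys, bufs, bufs_per_page = pool.get_buffer_with_hash(page_hashes, indices)
if indices is not None:
    assert len(bufs) == len(page_hashes) * bufs_per_page
    for i in range(len(page_hashes)):
        page_bufs = bufs[i * bufs_per_page : (i + 1) * bufs_per_page]
        # ... hand page_bufs to the storage backend for this page
```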
sglang/srt/mem_cache/radix_cache.py

```diff
@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
@@ -62,7 +60,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -152,6 +149,7 @@ class RadixCache(BasePrefixCache):
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
+        self.root_node.host_value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
         self.protected_size_ = 0
@@ -194,7 +192,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )

-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
         if self.disable:
             return 0

@@ -239,7 +237,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -260,7 +258,9 @@ class RadixCache(BasePrefixCache):
         page_aligned_token_ids = token_ids[:page_aligned_len]

         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
```
sglang/srt/mem_cache/radix_cache_cpp.py

```diff
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids
```
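In both radix-cache implementations `cache_unfinished_req` now accepts a `chunked` flag, and the Python version threads it into `insert`, so the tree can tell the continuation of a chunked prefill apart from a fresh insertion. A stripped-down sketch of the plumbing (not the real classes; the real signatures carry `Req` objects and KV indices):

```python
# Minimal plumbing sketch of the new `chunked` keyword.
class TinyRadixCache:
    def insert(self, key, value=None, chunked=False):
        # The real RadixCache uses `chunked` to adjust how the prefix of an
        # in-flight chunked-prefill request is accounted for; that policy
        # lives in sglang, not in this sketch.
        print(f"insert(len={len(key)}, chunked={chunked})")
        return 0

    def cache_unfinished_req(self, token_ids, kv_indices, chunked=False):
        self.insert(token_ids, kv_indices, chunked=chunked)


TinyRadixCache().cache_unfinished_req([1, 2, 3], [10, 11, 12], chunked=True)
```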
sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py (new file)

```diff
@@ -0,0 +1,164 @@
+import logging
+import os
+import threading
+from abc import ABC, abstractmethod
+from typing import List
+
+import torch
+
+
+class Hf3fsClient(ABC):
+    """Abstract interface for HF3FS clients."""
+
+    @abstractmethod
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize the HF3FS client.
+
+        Args:
+            path: File path for storage
+            size: Total size of storage file
+            bytes_per_page: Bytes per page
+            entries: Number of entries for batch operations
+        """
+        pass
+
+    @abstractmethod
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from storage."""
+        pass
+
+    @abstractmethod
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to storage."""
+        pass
+
+    @abstractmethod
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    @abstractmethod
+    def get_size(self) -> int:
+        """Get total storage size."""
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the client and cleanup resources."""
+        pass
+
+    @abstractmethod
+    def flush(self) -> None:
+        """Flush data to disk."""
+        pass
+
+
+logger = logging.getLogger(__name__)
+
+
+class Hf3fsMockClient(Hf3fsClient):
+    """Mock implementation of Hf3fsClient for CI testing purposes."""
+
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize mock HF3FS client."""
+        self.path = path
+        self.size = size
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+
+        # Create directory if it doesn't exist
+        os.makedirs(os.path.dirname(self.path), exist_ok=True)
+
+        # Create and initialize the file
+        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
+        os.ftruncate(self.file, size)
+
+        logger.info(
+            f"Hf3fsMockClient initialized: path={path}, size={size}, "
+            f"bytes_per_page={bytes_per_page}, entries={entries}"
+        )
+
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_read = os.read(self.file, size)
+
+                if len(bytes_read) == size:
+                    # Convert bytes to tensor and copy to target
+                    bytes_tensor = torch.frombuffer(bytes_read, dtype=torch.uint8)
+                    typed_tensor = bytes_tensor.view(tensor.dtype).view(tensor.shape)
+                    tensor.copy_(typed_tensor)
+                    results.append(size)
+                else:
+                    logger.warning(
+                        f"Short read: expected {size}, got {len(bytes_read)}"
+                    )
+                    results.append(len(bytes_read))
+
+            except Exception as e:
+                logger.error(f"Error reading from offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                # Convert tensor to bytes and write directly to file
+                tensor_bytes = tensor.contiguous().view(torch.uint8).flatten()
+                data = tensor_bytes.numpy().tobytes()
+
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_written = os.write(self.file, data)
+
+                if bytes_written == size:
+                    results.append(size)
+                else:
+                    logger.warning(f"Short write: expected {size}, got {bytes_written}")
+                    results.append(bytes_written)
+
+            except Exception as e:
+                logger.error(f"Error writing to offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    def get_size(self) -> int:
+        """Get total storage size."""
+        return self.size
+
+    def close(self) -> None:
+        """Close the mock client and cleanup resources."""
+        try:
+            if hasattr(self, "file") and self.file >= 0:
+                os.close(self.file)
+                self.file = -1  # Mark as closed
+                logger.info(f"MockHf3fsClient closed: {self.path}")
+        except Exception as e:
+            logger.error(f"Error closing MockHf3fsClient: {e}")
+
+    def flush(self) -> None:
+        """Flush data to disk."""
+        try:
+            os.fsync(self.file)
+        except Exception as e:
+            logger.error(f"Error flushing MockHf3fsClient: {e}")
```
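The new module gives HF3FS two interchangeable backends behind one ABC: the real usrbio client (renamed below) and this file-backed mock for CI. A plausible round-trip with the mock, under assumed values (the path and sizes are arbitrary):

```python
# Round-trip through the mock client shown above; values are arbitrary.
import torch

from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient

client = Hf3fsMockClient(
    "/tmp/hf3fs_mock/data.bin", size=1 << 20, bytes_per_page=4096, entries=8
)
src = torch.arange(1024, dtype=torch.int32)   # 4096 bytes: exactly one page
dst = torch.empty_like(src)
assert client.batch_write([0], [src]) == [4096]  # per-op byte counts
assert client.batch_read([0], [dst]) == [4096]
assert torch.equal(src, dst)
client.flush()
client.close()
```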
sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py}

```diff
@@ -9,6 +9,8 @@ from typing import List
 import torch
 from torch.utils.cpp_extension import load

+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+
 root = Path(__file__).parent.resolve()
 hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])

@@ -51,7 +53,9 @@ def wsynchronized():
     return _decorator


-class Hf3fsClient:
+class Hf3fsUsrBioClient(Hf3fsClient):
+    """HF3FS client implementation using usrbio."""
+
     def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
         if not HF3FS_AVAILABLE:
             raise ImportError(
```
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py

```diff
@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple

+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request,
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

@@ -24,10 +26,10 @@ class RankMetadata:
     """Holds all metadata for a single rank."""

     def __init__(self, num_pages: int):
-        self.lock = threading.
+        self.lock = threading.Lock()
         self.num_pages = num_pages
         self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index:
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
         # Todo: Support multi files for HF3FS

     def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
         for i, (key, prefix_key) in enumerate(keys):
             if key in self.key_to_index:
                 results[i] = (True, self.key_to_index[key])
+                self.key_to_index.move_to_end(key)
             else:
                 new_keys_to_process.append((i, key, prefix_key))

         # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
         for i, key, prefix_key in new_keys_to_process:
             if len(self.free_pages) > 0:
-
-                results[i] = (False, page_idx)
+                page_index = self.free_pages.pop()
             else:
-
+                page_index = self.key_to_index.popitem(last=False)[1]
+
+            results[i] = (False, page_index)

         return results

@@ -68,6 +72,7 @@ class RankMetadata:
         with self.lock:
             for key, page_index in written_keys_to_confirm:
                 self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)

             for page_index in pages_to_release:
                 if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
     def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
         """Get page indices for keys."""
         with self.lock:
-
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results


 class GlobalMetadataState:
```
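Taken together, the `RankMetadata` hunks turn `key_to_index` into an LRU structure: every hit is refreshed with `move_to_end`, and once `free_pages` is exhausted the allocator evicts the coldest entry via `popitem(last=False)` and reuses its page. The pattern in isolation (capacity and keys are made up):

```python
# Standalone illustration of the OrderedDict-as-LRU pattern the diff adopts.
from collections import OrderedDict

class LruPageTable:
    def __init__(self, num_pages: int):
        self.free_pages = list(range(num_pages))
        self.key_to_index = OrderedDict()

    def get(self, key):
        index = self.key_to_index.get(key)
        if index is not None:
            self.key_to_index.move_to_end(key)  # mark as most recently used
        return index

    def allocate(self, key):
        if self.free_pages:
            index = self.free_pages.pop()
        else:
            # Evict the least recently used key and reuse its page.
            index = self.key_to_index.popitem(last=False)[1]
        self.key_to_index[key] = index
        return index

table = LruPageTable(num_pages=2)
a, b = table.allocate("a"), table.allocate("b")
table.get("a")              # "a" becomes most recently used
c = table.allocate("c")     # evicts "b" and reuses its page
assert c == b and table.get("b") is None
```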
```diff
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:

     def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
         self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
         self._setup_routes()

     def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:

     def get_rank_metadata(self, rank: int) -> RankMetadata:
         """Get rank metadata with proper error handling."""
-
-
-
-
-
-
-
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)

     async def initialize(self, rank: int, request: Request):
         """Initialize a rank with specified number of pages."""
-        data = await
+        data = await self._read_json(request)
         num_pages = data["num_pages"]
         with self.state.global_lock:
             if rank in self.state.ranks:
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
         else:
             logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
             self.state.ranks[rank] = RankMetadata(num_pages)
-        return
+        return Response(status_code=204)

     async def exists(self, rank: int, request: Request):
         """Check if keys exist in metadata."""
-        data = await
+        data = await self._read_json(request)
         keys = data["keys"]
         metadata = self.get_rank_metadata(rank)
         results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})

     async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
         """Reserve and allocate page indices for keys."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     async def confirm_write(self, rank: int, request: Request):
         """Confirm write operations and release pages."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         success_written_keys = data.get("written_keys_to_confirm", [])
         released_pages = data.get("pages_to_release", [])

         metadata.confirm_write(success_written_keys, released_pages)

-        return
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)

     async def delete_keys(self, rank: int, request: Request):
         """Delete keys from metadata."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         count = metadata.delete_keys(data["keys"])
-        return
+        return Response(status_code=204)

     async def clear(self, rank: int):
         """Clear all metadata for a rank."""
         metadata = self.get_rank_metadata(rank)
         metadata.clear_all()
-        return
+        return Response(status_code=204)

     async def get_page_indices(self, rank: int, request: Request):
         """Get page indices for keys."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     def run(self, host: str = "0.0.0.0", port: int = 18000):
         """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
             status_forcelist=[500, 502, 503, 504],
             allowed_methods=["GET", "POST"],
         )
-        adapter = HTTPAdapter(
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
         self._session.mount("http://", adapter)

     def _post(self, endpoint: str, json_data: dict) -> dict:
         try:
-
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
             response.raise_for_status()
-
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
         except requests.exceptions.RequestException as e:
             logging.error(f"Failed to POST to {endpoint} after retries: {e}")
             raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
```