sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py (new file)

@@ -0,0 +1,164 @@
+import logging
+import os
+import threading
+from abc import ABC, abstractmethod
+from typing import List
+
+import torch
+
+
+class Hf3fsClient(ABC):
+    """Abstract interface for HF3FS clients."""
+
+    @abstractmethod
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize the HF3FS client.
+
+        Args:
+            path: File path for storage
+            size: Total size of storage file
+            bytes_per_page: Bytes per page
+            entries: Number of entries for batch operations
+        """
+        pass
+
+    @abstractmethod
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from storage."""
+        pass
+
+    @abstractmethod
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to storage."""
+        pass
+
+    @abstractmethod
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    @abstractmethod
+    def get_size(self) -> int:
+        """Get total storage size."""
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the client and cleanup resources."""
+        pass
+
+    @abstractmethod
+    def flush(self) -> None:
+        """Flush data to disk."""
+        pass
+
+
+logger = logging.getLogger(__name__)
+
+
+class Hf3fsMockClient(Hf3fsClient):
+    """Mock implementation of Hf3fsClient for CI testing purposes."""
+
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize mock HF3FS client."""
+        self.path = path
+        self.size = size
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+
+        # Create directory if it doesn't exist
+        os.makedirs(os.path.dirname(self.path), exist_ok=True)
+
+        # Create and initialize the file
+        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
+        os.ftruncate(self.file, size)
+
+        logger.info(
+            f"Hf3fsMockClient initialized: path={path}, size={size}, "
+            f"bytes_per_page={bytes_per_page}, entries={entries}"
+        )
+
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_read = os.read(self.file, size)
+
+                if len(bytes_read) == size:
+                    # Convert bytes to tensor and copy to target
+                    bytes_tensor = torch.frombuffer(bytes_read, dtype=torch.uint8)
+                    typed_tensor = bytes_tensor.view(tensor.dtype).view(tensor.shape)
+                    tensor.copy_(typed_tensor)
+                    results.append(size)
+                else:
+                    logger.warning(
+                        f"Short read: expected {size}, got {len(bytes_read)}"
+                    )
+                    results.append(len(bytes_read))
+
+            except Exception as e:
+                logger.error(f"Error reading from offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                # Convert tensor to bytes and write directly to file
+                tensor_bytes = tensor.contiguous().view(torch.uint8).flatten()
+                data = tensor_bytes.numpy().tobytes()
+
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_written = os.write(self.file, data)
+
+                if bytes_written == size:
+                    results.append(size)
+                else:
+                    logger.warning(f"Short write: expected {size}, got {bytes_written}")
+                    results.append(bytes_written)
+
+            except Exception as e:
+                logger.error(f"Error writing to offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    def get_size(self) -> int:
+        """Get total storage size."""
+        return self.size
+
+    def close(self) -> None:
+        """Close the mock client and cleanup resources."""
+        try:
+            if hasattr(self, "file") and self.file >= 0:
+                os.close(self.file)
+                self.file = -1  # Mark as closed
+                logger.info(f"MockHf3fsClient closed: {self.path}")
+        except Exception as e:
+            logger.error(f"Error closing MockHf3fsClient: {e}")
+
+    def flush(self) -> None:
+        """Flush data to disk."""
+        try:
+            os.fsync(self.file)
+        except Exception as e:
+            logger.error(f"Error flushing MockHf3fsClient: {e}")
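For context, a minimal sketch of how the new mock client could be exercised end to end. The path, sizes, and tensor shapes below are arbitrary examples chosen for illustration, not values used by sglang itself.

```python
# Illustrative round trip through Hf3fsMockClient (the file-backed mock added above).
# batch_write/batch_read take parallel lists of byte offsets and tensors and return
# the number of bytes handled per entry.
import torch

from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient

client = Hf3fsMockClient(
    path="/tmp/hf3fs_mock/data.bin",  # backing file, created on demand
    size=16 * 1024 * 1024,            # 16 MiB backing file
    bytes_per_page=4096,
    entries=8,
)

src = torch.arange(1024, dtype=torch.float32)
dst = torch.empty_like(src)

# Write src at offset 0, then read it back into dst.
assert client.batch_write([0], [src]) == [src.numel() * src.itemsize]
assert client.batch_read([0], [dst]) == [dst.numel() * dst.itemsize]
assert torch.equal(src, dst)

client.flush()
client.close()
```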
sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py}

@@ -9,6 +9,8 @@ from typing import List
 import torch
 from torch.utils.cpp_extension import load
 
+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+
 root = Path(__file__).parent.resolve()
 hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
 
@@ -51,7 +53,9 @@ def wsynchronized():
     return _decorator
 
 
-class Hf3fsClient:
+class Hf3fsUsrBioClient(Hf3fsClient):
+    """HF3FS client implementation using usrbio."""
+
     def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
         if not HF3FS_AVAILABLE:
             raise ImportError(
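The rename above turns the old concrete `Hf3fsClient` into `Hf3fsUsrBioClient`, a subclass of the new abstract `Hf3fsClient` interface. A hypothetical factory helper (not code from this release; `storage_hf3fs.py` is not shown in this excerpt) illustrates how the two implementations can now be swapped behind one interface:

```python
# Hypothetical helper, assuming only the class/module names introduced in this diff.
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient, Hf3fsMockClient


def make_hf3fs_client(
    use_mock: bool, path: str, size: int, bytes_per_page: int, entries: int
) -> Hf3fsClient:
    if use_mock:
        # CI-friendly mock backed by a plain file.
        return Hf3fsMockClient(path, size, bytes_per_page, entries)
    # Imported lazily: the usrbio client builds a C++ extension and requires HF3FS.
    from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import Hf3fsUsrBioClient

    return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
```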
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py

@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple
 
+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request,
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
@@ -24,10 +26,10 @@ class RankMetadata:
     """Holds all metadata for a single rank."""
 
     def __init__(self, num_pages: int):
-        self.lock = threading.
+        self.lock = threading.Lock()
         self.num_pages = num_pages
         self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index:
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
         # Todo: Support multi files for HF3FS
 
     def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
         for i, (key, prefix_key) in enumerate(keys):
             if key in self.key_to_index:
                 results[i] = (True, self.key_to_index[key])
+                self.key_to_index.move_to_end(key)
             else:
                 new_keys_to_process.append((i, key, prefix_key))
 
         # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
         for i, key, prefix_key in new_keys_to_process:
             if len(self.free_pages) > 0:
-
-                results[i] = (False, page_idx)
+                page_index = self.free_pages.pop()
             else:
-
+                page_index = self.key_to_index.popitem(last=False)[1]
+
+            results[i] = (False, page_index)
 
         return results
 
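The `move_to_end`/`popitem(last=False)` calls above give `RankMetadata.key_to_index` LRU behavior: hits refresh a key's recency, and once no free page remains the least-recently-used key's page is reused. A standalone sketch of that pattern (not the `RankMetadata` code itself):

```python
from collections import OrderedDict

key_to_index: OrderedDict[str, int] = OrderedDict()
free_pages = [2, 1, 0]


def reserve(key: str) -> int:
    if key in key_to_index:
        key_to_index.move_to_end(key)      # hit: mark as most recently used
        return key_to_index[key]
    if free_pages:
        page = free_pages.pop()            # miss, capacity left
    else:
        # miss, no free page: evict the LRU key and reuse its page index
        page = key_to_index.popitem(last=False)[1]
    key_to_index[key] = page
    return page


for k in ["a", "b", "c", "a", "d"]:
    print(k, reserve(k))  # "d" reuses the page that belonged to the evicted "b"
```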
@@ -68,6 +72,7 @@ class RankMetadata:
         with self.lock:
             for key, page_index in written_keys_to_confirm:
                 self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)
 
             for page_index in pages_to_release:
                 if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
     def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
         """Get page indices for keys."""
         with self.lock:
-
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results
 
 
 class GlobalMetadataState:
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
 
     def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
         self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
         self._setup_routes()
 
     def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
 
     def get_rank_metadata(self, rank: int) -> RankMetadata:
         """Get rank metadata with proper error handling."""
-
-
-
-
-
-
-
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)
 
     async def initialize(self, rank: int, request: Request):
         """Initialize a rank with specified number of pages."""
-        data = await
+        data = await self._read_json(request)
         num_pages = data["num_pages"]
         with self.state.global_lock:
             if rank in self.state.ranks:
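The helpers added above have the metadata server parse request bodies with `orjson` and return `ORJSONResponse` objects directly. A minimal standalone FastAPI sketch of the same pattern (the endpoint name and payload here are made up for illustration):

```python
import orjson
from fastapi import FastAPI, Request
from fastapi.responses import ORJSONResponse

app = FastAPI(default_response_class=ORJSONResponse)


@app.post("/echo_keys")
async def echo_keys(request: Request):
    # Parse the raw body with orjson ...
    data = orjson.loads(await request.body())
    # ... and return an ORJSONResponse directly, skipping jsonable_encoder.
    return ORJSONResponse({"keys": data.get("keys", [])})
```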
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
             else:
                 logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                 self.state.ranks[rank] = RankMetadata(num_pages)
-        return
+        return Response(status_code=204)
 
     async def exists(self, rank: int, request: Request):
         """Check if keys exist in metadata."""
-        data = await
+        data = await self._read_json(request)
         keys = data["keys"]
         metadata = self.get_rank_metadata(rank)
         results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})
 
     async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
         """Reserve and allocate page indices for keys."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})
 
     async def confirm_write(self, rank: int, request: Request):
         """Confirm write operations and release pages."""
-        data = await
+        data = await self._read_json(request)
        metadata = self.get_rank_metadata(rank)
         success_written_keys = data.get("written_keys_to_confirm", [])
         released_pages = data.get("pages_to_release", [])
 
         metadata.confirm_write(success_written_keys, released_pages)
 
-        return
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)
 
     async def delete_keys(self, rank: int, request: Request):
         """Delete keys from metadata."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         count = metadata.delete_keys(data["keys"])
-        return
+        return Response(status_code=204)
 
     async def clear(self, rank: int):
         """Clear all metadata for a rank."""
         metadata = self.get_rank_metadata(rank)
         metadata.clear_all()
-        return
+        return Response(status_code=204)
 
     async def get_page_indices(self, rank: int, request: Request):
         """Get page indices for keys."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})
 
     def run(self, host: str = "0.0.0.0", port: int = 18000):
         """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
             status_forcelist=[500, 502, 503, 504],
             allowed_methods=["GET", "POST"],
         )
-        adapter = HTTPAdapter(
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
         self._session.mount("http://", adapter)
 
     def _post(self, endpoint: str, json_data: dict) -> dict:
         try:
-
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
             response.raise_for_status()
-
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
         except requests.exceptions.RequestException as e:
             logging.error(f"Failed to POST to {endpoint} after retries: {e}")
             raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
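On the client side, `_post` now serializes payloads with `orjson` and maps 204 or empty replies to an empty dict, matching the new no-content responses from `confirm_write`, `delete_keys`, and `clear`. A self-contained sketch of that request pattern (the base URL and endpoint below are placeholders, not routes confirmed by this excerpt):

```python
import orjson
import requests


def post_json(base_url: str, endpoint: str, payload: dict) -> dict:
    # Serialize with orjson and send raw bytes; the server parses with orjson too.
    resp = requests.post(
        f"{base_url}/{endpoint}",
        data=orjson.dumps(payload),
        headers={"Content-Type": "application/json"},
    )
    resp.raise_for_status()
    # 204 or an empty body means "acknowledged, nothing to return".
    if resp.status_code == 204 or not resp.content:
        return {}
    return orjson.loads(resp.content)


# Usage (hypothetical endpoint): post_json("http://localhost:18000", "example", {"keys": ["k1"]})
```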