sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
|
20
20
|
MLATokenToKVPoolHost,
|
21
21
|
)
|
22
22
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
23
|
+
from sglang.srt.metrics.collector import StorageMetricsCollector
|
23
24
|
|
24
25
|
logger = logging.getLogger(__name__)
|
25
26
|
|
@@ -37,8 +38,11 @@ class HiRadixCache(RadixCache):
|
|
37
38
|
hicache_write_policy: str,
|
38
39
|
hicache_io_backend: str,
|
39
40
|
hicache_mem_layout: str,
|
41
|
+
enable_metrics: bool,
|
40
42
|
hicache_storage_backend: Optional[str] = None,
|
41
43
|
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
44
|
+
model_name: Optional[str] = None,
|
45
|
+
storage_backend_extra_config: Optional[str] = None,
|
42
46
|
):
|
43
47
|
|
44
48
|
if hicache_io_backend == "direct":
|
@@ -71,6 +75,8 @@ class HiRadixCache(RadixCache):
|
|
71
75
|
self.tp_group = tp_cache_group
|
72
76
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
73
77
|
self.enable_storage = hicache_storage_backend is not None
|
78
|
+
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
79
|
+
|
74
80
|
# todo: customizable storage prefetch threshold and timeout
|
75
81
|
self.prefetch_threshold = 256
|
76
82
|
self.prefetch_timeout = 3 # seconds
|
@@ -87,7 +93,17 @@ class HiRadixCache(RadixCache):
|
|
87
93
|
io_backend=hicache_io_backend,
|
88
94
|
storage_backend=hicache_storage_backend,
|
89
95
|
prefetch_threshold=self.prefetch_threshold,
|
96
|
+
model_name=model_name,
|
97
|
+
storage_backend_extra_config=storage_backend_extra_config,
|
90
98
|
)
|
99
|
+
if self.enable_storage_metrics:
|
100
|
+
# TODO: support pp
|
101
|
+
labels = {
|
102
|
+
"storage_backend": hicache_storage_backend,
|
103
|
+
"tp_rank": self.cache_controller.tp_rank,
|
104
|
+
"dp_rank": self.cache_controller.dp_rank,
|
105
|
+
}
|
106
|
+
self.metrics_collector = StorageMetricsCollector(labels=labels)
|
91
107
|
|
92
108
|
# record the nodes with ongoing write through
|
93
109
|
self.ongoing_write_through = {}
|
@@ -98,10 +114,7 @@ class HiRadixCache(RadixCache):
|
|
98
114
|
self.ongoing_backup = {}
|
99
115
|
# todo: dynamically adjust the threshold
|
100
116
|
self.write_through_threshold = (
|
101
|
-
1 if hicache_write_policy == "write_through" else
|
102
|
-
)
|
103
|
-
self.write_through_threshold_storage = (
|
104
|
-
1 if hicache_write_policy == "write_through" else 3
|
117
|
+
1 if hicache_write_policy == "write_through" else 2
|
105
118
|
)
|
106
119
|
self.load_back_threshold = 10
|
107
120
|
super().__init__(
|
@@ -121,6 +134,28 @@ class HiRadixCache(RadixCache):
|
|
121
134
|
height += 1
|
122
135
|
return height
|
123
136
|
|
137
|
+
def clear_storage_backend(self) -> bool:
|
138
|
+
if self.enable_storage:
|
139
|
+
try:
|
140
|
+
# Check if the storage backend has a clear method (for nixl backends)
|
141
|
+
if hasattr(self.cache_controller.storage_backend, "clear"):
|
142
|
+
self.cache_controller.storage_backend.clear()
|
143
|
+
logger.info(
|
144
|
+
"Hierarchical cache storage backend cleared successfully!"
|
145
|
+
)
|
146
|
+
return True
|
147
|
+
else:
|
148
|
+
logger.warning(
|
149
|
+
f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
|
150
|
+
)
|
151
|
+
return False
|
152
|
+
except Exception as e:
|
153
|
+
logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
|
154
|
+
return False
|
155
|
+
else:
|
156
|
+
logger.warning("Hierarchical cache storage backend is not enabled.")
|
157
|
+
return False
|
158
|
+
|
124
159
|
def write_backup(self, node: TreeNode, write_back=False):
|
125
160
|
host_indices = self.cache_controller.write(
|
126
161
|
device_indices=node.value,
|
@@ -151,8 +186,9 @@ class HiRadixCache(RadixCache):
|
|
151
186
|
self.ongoing_backup[operation_id] = node
|
152
187
|
node.protect_host()
|
153
188
|
|
154
|
-
def
|
155
|
-
|
189
|
+
def _inc_hit_count(self, node: TreeNode, chunked=False):
|
190
|
+
# skip the hit count update for chunked requests
|
191
|
+
if self.cache_controller.write_policy == "write_back" or chunked:
|
156
192
|
return
|
157
193
|
node.hit_count += 1
|
158
194
|
|
@@ -160,51 +196,62 @@ class HiRadixCache(RadixCache):
|
|
160
196
|
if node.hit_count >= self.write_through_threshold:
|
161
197
|
# write to host if the node is not backuped
|
162
198
|
self.write_backup(node)
|
163
|
-
else:
|
164
|
-
if (
|
165
|
-
self.enable_storage
|
166
|
-
and (not node.backuped_storage)
|
167
|
-
and node.hit_count >= self.write_through_threshold_storage
|
168
|
-
):
|
169
|
-
# if the node is backuped on host memory but not on storage
|
170
|
-
self.write_backup_storage(node)
|
171
199
|
|
172
200
|
def writing_check(self, write_back=False):
|
173
201
|
if write_back:
|
174
202
|
# blocking till all write back complete
|
175
203
|
while len(self.ongoing_write_through) > 0:
|
176
|
-
|
177
|
-
|
204
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
205
|
+
finish_event.synchronize()
|
206
|
+
for ack_id in ack_list:
|
207
|
+
del self.ongoing_write_through[ack_id]
|
208
|
+
self.cache_controller.ack_write_queue.clear()
|
209
|
+
assert len(self.ongoing_write_through) == 0
|
178
210
|
return
|
179
|
-
|
180
|
-
|
181
|
-
)
|
211
|
+
|
212
|
+
# NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
|
213
|
+
if len(self.ongoing_write_through) == 0:
|
214
|
+
return
|
215
|
+
|
216
|
+
finish_count = 0
|
217
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
218
|
+
if not finish_event.query():
|
219
|
+
break
|
220
|
+
finish_count += 1
|
221
|
+
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
|
182
222
|
if self.tp_world_size > 1:
|
183
|
-
#
|
223
|
+
# synchronize TP workers to make the same update to radix cache
|
184
224
|
torch.distributed.all_reduce(
|
185
225
|
queue_size,
|
186
226
|
op=torch.distributed.ReduceOp.MIN,
|
187
227
|
group=self.tp_group,
|
188
228
|
)
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
229
|
+
|
230
|
+
finish_count = int(queue_size.item())
|
231
|
+
while finish_count > 0:
|
232
|
+
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
233
|
+
finish_event.synchronize()
|
234
|
+
for ack_id in ack_list:
|
235
|
+
backuped_node = self.ongoing_write_through.pop(ack_id)
|
236
|
+
self.dec_lock_ref(backuped_node)
|
237
|
+
if self.enable_storage:
|
238
|
+
self.write_backup_storage(backuped_node)
|
239
|
+
finish_count -= 1
|
193
240
|
|
194
241
|
def loading_check(self):
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
self.dec_lock_ref(end_node)
|
200
|
-
while end_node != start_node:
|
201
|
-
assert end_node.loading
|
202
|
-
end_node.loading = False
|
203
|
-
end_node = end_node.parent
|
204
|
-
# clear the reference
|
205
|
-
del self.ongoing_load_back[ack_id]
|
206
|
-
except Exception:
|
242
|
+
finish_count = 0
|
243
|
+
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
|
244
|
+
if not finish_event.query():
|
245
|
+
# the KV cache loading is still ongoing
|
207
246
|
break
|
247
|
+
finish_count += 1
|
248
|
+
# no need to sync across TP workers as batch forwarding is synced
|
249
|
+
for ack_id in ack_list:
|
250
|
+
end_node = self.ongoing_load_back.pop(ack_id)
|
251
|
+
self.dec_lock_ref(end_node)
|
252
|
+
|
253
|
+
# ACK until all events are processed
|
254
|
+
del self.cache_controller.ack_load_queue[:finish_count]
|
208
255
|
|
209
256
|
def evictable_size(self):
|
210
257
|
return self.evictable_size_
|
@@ -329,12 +376,11 @@ class HiRadixCache(RadixCache):
|
|
329
376
|
# no sufficient GPU memory to load back KV caches
|
330
377
|
return None
|
331
378
|
|
332
|
-
self.ongoing_load_back[last_hit_node.id] =
|
379
|
+
self.ongoing_load_back[last_hit_node.id] = last_hit_node
|
333
380
|
offset = 0
|
334
381
|
for node in nodes_to_load:
|
335
382
|
node.value = device_indices[offset : offset + len(node.host_value)]
|
336
383
|
offset += len(node.host_value)
|
337
|
-
node.loading = True
|
338
384
|
self.evictable_size_ += len(device_indices)
|
339
385
|
self.inc_lock_ref(last_hit_node)
|
340
386
|
|
@@ -363,66 +409,72 @@ class HiRadixCache(RadixCache):
|
|
363
409
|
last_node,
|
364
410
|
)
|
365
411
|
|
366
|
-
def ready_to_load_host_cache(self):
|
367
|
-
|
368
|
-
|
369
|
-
|
412
|
+
def ready_to_load_host_cache(self) -> int:
|
413
|
+
"""
|
414
|
+
Notify the cache controller to start the KV cache loading.
|
415
|
+
Return the consumer index for the schedule batch manager to track.
|
416
|
+
"""
|
417
|
+
return self.cache_controller.start_loading()
|
370
418
|
|
371
419
|
def check_hicache_events(self):
|
372
420
|
self.writing_check()
|
373
421
|
self.loading_check()
|
374
422
|
if self.enable_storage:
|
375
|
-
self.
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
queue_size = torch.tensor(
|
380
|
-
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
|
381
|
-
)
|
382
|
-
if self.tp_world_size > 1:
|
383
|
-
# synchrnoize TP workers to make the same update to hiradix cache
|
384
|
-
torch.distributed.all_reduce(
|
385
|
-
queue_size,
|
386
|
-
op=torch.distributed.ReduceOp.MIN,
|
387
|
-
group=self.tp_group,
|
423
|
+
self.drain_storage_control_queues()
|
424
|
+
if self.enable_storage_metrics:
|
425
|
+
self.metrics_collector.log_storage_metrics(
|
426
|
+
self.cache_controller.storage_backend.get_stats()
|
388
427
|
)
|
389
|
-
for _ in range(queue_size.item()):
|
390
|
-
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
391
|
-
if req_id in self.ongoing_prefetch:
|
392
|
-
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
393
|
-
last_host_node.release_host()
|
394
|
-
del self.ongoing_prefetch[req_id]
|
395
|
-
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
396
|
-
else:
|
397
|
-
# the revoked operation already got terminated
|
398
|
-
pass
|
399
428
|
|
400
|
-
def
|
401
|
-
|
402
|
-
|
429
|
+
def drain_storage_control_queues(self):
|
430
|
+
"""
|
431
|
+
Combine prefetch revoke, backup ack, and host mem release checks
|
432
|
+
to minimize TP synchronization and Python overhead.
|
433
|
+
"""
|
434
|
+
cc = self.cache_controller
|
435
|
+
|
436
|
+
qsizes = torch.tensor(
|
437
|
+
[
|
438
|
+
cc.prefetch_revoke_queue.qsize(),
|
439
|
+
cc.ack_backup_queue.qsize(),
|
440
|
+
cc.host_mem_release_queue.qsize(),
|
441
|
+
],
|
442
|
+
dtype=torch.int,
|
403
443
|
)
|
404
444
|
if self.tp_world_size > 1:
|
405
|
-
# synchrnoize TP workers to make the same update to hiradix cache
|
406
445
|
torch.distributed.all_reduce(
|
407
|
-
|
408
|
-
op=torch.distributed.ReduceOp.MIN,
|
409
|
-
group=self.tp_group,
|
446
|
+
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
410
447
|
)
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
448
|
+
|
449
|
+
n_revoke, n_backup, n_release = map(int, qsizes.tolist())
|
450
|
+
|
451
|
+
# process prefetch revokes
|
452
|
+
for _ in range(n_revoke):
|
453
|
+
req_id = cc.prefetch_revoke_queue.get()
|
454
|
+
info = self.ongoing_prefetch.pop(req_id, None)
|
455
|
+
if info is not None:
|
456
|
+
last_host_node, token_ids, _, _ = info
|
457
|
+
last_host_node.release_host()
|
458
|
+
cc.prefetch_tokens_occupied -= len(token_ids)
|
459
|
+
# else: the revoked operation already got terminated, nothing to do
|
460
|
+
|
461
|
+
# process backup acks
|
462
|
+
for _ in range(n_backup):
|
463
|
+
operation = cc.ack_backup_queue.get()
|
464
|
+
ack_id = operation.id
|
465
|
+
entry = self.ongoing_backup.pop(ack_id, None)
|
466
|
+
if entry is not None:
|
467
|
+
entry.release_host()
|
468
|
+
if self.enable_storage_metrics:
|
469
|
+
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
|
470
|
+
|
471
|
+
# release host memory
|
472
|
+
host_indices_list = []
|
473
|
+
for _ in range(n_release):
|
474
|
+
host_indices_list.append(cc.host_mem_release_queue.get())
|
475
|
+
if host_indices_list:
|
476
|
+
host_indices = torch.cat(host_indices_list, dim=0)
|
477
|
+
cc.mem_pool_host.free(host_indices)
|
426
478
|
|
427
479
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
428
480
|
can_terminate = True
|
@@ -430,9 +482,12 @@ class HiRadixCache(RadixCache):
|
|
430
482
|
if self.prefetch_stop_policy == "best_effort":
|
431
483
|
return can_terminate
|
432
484
|
|
433
|
-
|
434
|
-
|
435
|
-
|
485
|
+
if len(operation.hash_value) == 0:
|
486
|
+
completed = False
|
487
|
+
else:
|
488
|
+
completed = (
|
489
|
+
operation.completed_tokens == len(operation.hash_value) * self.page_size
|
490
|
+
)
|
436
491
|
|
437
492
|
if self.prefetch_stop_policy == "wait_complete":
|
438
493
|
can_terminate = completed
|
@@ -444,15 +499,22 @@ class HiRadixCache(RadixCache):
|
|
444
499
|
# unknown prefetch stop policy, just return True
|
445
500
|
return True
|
446
501
|
|
502
|
+
operation_terminated = operation.is_terminated()
|
447
503
|
if self.tp_world_size > 1:
|
448
|
-
|
504
|
+
states = torch.tensor(
|
505
|
+
[1 - int(can_terminate), int(operation_terminated)],
|
506
|
+
dtype=torch.int,
|
507
|
+
)
|
449
508
|
torch.distributed.all_reduce(
|
450
|
-
|
451
|
-
op=torch.distributed.ReduceOp.
|
509
|
+
states,
|
510
|
+
op=torch.distributed.ReduceOp.MAX,
|
452
511
|
group=self.tp_group,
|
453
512
|
)
|
454
|
-
can_terminate =
|
455
|
-
|
513
|
+
can_terminate = states[0].item() == 0
|
514
|
+
operation_terminated = states[1].item() == 1
|
515
|
+
# the operation should be terminated if it is already terminated on any TP worker
|
516
|
+
# or it meets the termination condition on all TP workers
|
517
|
+
can_terminate = can_terminate or operation_terminated
|
456
518
|
return can_terminate
|
457
519
|
|
458
520
|
def check_prefetch_progress(self, req_id: str) -> bool:
|
@@ -479,7 +541,7 @@ class HiRadixCache(RadixCache):
|
|
479
541
|
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
480
542
|
|
481
543
|
min_completed_tokens = completed_tokens
|
482
|
-
if self.tp_world_size > 1
|
544
|
+
if self.tp_world_size > 1:
|
483
545
|
# synchrnoize TP workers to make the same update to hiradix cache
|
484
546
|
completed_tokens_tensor = torch.tensor(
|
485
547
|
min_completed_tokens, dtype=torch.int
|
@@ -502,13 +564,18 @@ class HiRadixCache(RadixCache):
|
|
502
564
|
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
503
565
|
|
504
566
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
505
|
-
self.cache_controller.
|
567
|
+
self.cache_controller.append_host_mem_release(
|
506
568
|
host_indices[min_completed_tokens:completed_tokens]
|
507
569
|
)
|
508
570
|
last_host_node.release_host()
|
509
571
|
del self.ongoing_prefetch[req_id]
|
510
572
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
511
573
|
|
574
|
+
if self.enable_storage_metrics:
|
575
|
+
self.metrics_collector.log_prefetched_tokens(
|
576
|
+
min_completed_tokens - matched_length
|
577
|
+
)
|
578
|
+
|
512
579
|
return True
|
513
580
|
|
514
581
|
def match_prefix(self, key: List[int], **kwargs):
|
@@ -536,6 +603,8 @@ class HiRadixCache(RadixCache):
|
|
536
603
|
while last_node.evicted:
|
537
604
|
host_hit_length += len(last_node.host_value)
|
538
605
|
last_node = last_node.parent
|
606
|
+
while not last_host_node.backuped:
|
607
|
+
last_host_node = last_host_node.parent
|
539
608
|
|
540
609
|
return MatchResult(
|
541
610
|
device_indices=value,
|
@@ -556,7 +625,11 @@ class HiRadixCache(RadixCache):
|
|
556
625
|
len(new_input_tokens) % self.page_size
|
557
626
|
)
|
558
627
|
new_input_tokens = new_input_tokens[:prefetch_length]
|
559
|
-
if
|
628
|
+
if (
|
629
|
+
not self.enable_storage
|
630
|
+
or prefetch_length < self.prefetch_threshold
|
631
|
+
or self.cache_controller.prefetch_rate_limited()
|
632
|
+
):
|
560
633
|
return
|
561
634
|
|
562
635
|
last_host_node.protect_host()
|
@@ -564,6 +637,10 @@ class HiRadixCache(RadixCache):
|
|
564
637
|
if host_indices is None:
|
565
638
|
self.evict_host(prefetch_length)
|
566
639
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
640
|
+
if host_indices is None:
|
641
|
+
last_host_node.release_host()
|
642
|
+
# no sufficient host memory for prefetch
|
643
|
+
return
|
567
644
|
operation = self.cache_controller.prefetch(
|
568
645
|
req_id, host_indices, new_input_tokens, last_hash
|
569
646
|
)
|
@@ -642,7 +719,6 @@ class HiRadixCache(RadixCache):
|
|
642
719
|
new_node.parent = child.parent
|
643
720
|
new_node.lock_ref = child.lock_ref
|
644
721
|
new_node.key = child.key[:split_len]
|
645
|
-
new_node.loading = child.loading
|
646
722
|
new_node.hit_count = child.hit_count
|
647
723
|
|
648
724
|
# split value and host value if exists
|
@@ -663,11 +739,11 @@ class HiRadixCache(RadixCache):
|
|
663
739
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
664
740
|
return new_node
|
665
741
|
|
666
|
-
def
|
667
|
-
node.last_access_time = time.monotonic()
|
742
|
+
def insert(self, key: List, value, chunked=False):
|
668
743
|
if len(key) == 0:
|
669
744
|
return 0
|
670
745
|
|
746
|
+
node = self.root_node
|
671
747
|
child_key = self.get_child_key_fn(key)
|
672
748
|
total_prefix_length = 0
|
673
749
|
|
@@ -684,7 +760,7 @@ class HiRadixCache(RadixCache):
|
|
684
760
|
self.token_to_kv_pool_host.update_synced(node.host_value)
|
685
761
|
self.evictable_size_ += len(node.value)
|
686
762
|
else:
|
687
|
-
self.
|
763
|
+
self._inc_hit_count(node, chunked)
|
688
764
|
total_prefix_length += prefix_len
|
689
765
|
else:
|
690
766
|
# partial match, split the node
|
@@ -694,7 +770,7 @@ class HiRadixCache(RadixCache):
|
|
694
770
|
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
695
771
|
self.evictable_size_ += len(new_node.value)
|
696
772
|
else:
|
697
|
-
self.
|
773
|
+
self._inc_hit_count(new_node, chunked)
|
698
774
|
total_prefix_length += prefix_len
|
699
775
|
node = new_node
|
700
776
|
|
@@ -728,7 +804,7 @@ class HiRadixCache(RadixCache):
|
|
728
804
|
last_hash = new_node.hash_value[-1]
|
729
805
|
|
730
806
|
if self.cache_controller.write_policy != "write_back":
|
731
|
-
self.
|
807
|
+
self._inc_hit_count(new_node, chunked)
|
732
808
|
return total_prefix_length
|
733
809
|
|
734
810
|
def _collect_leaves_device(self):
|
@@ -755,3 +831,19 @@ class HiRadixCache(RadixCache):
|
|
755
831
|
if not cur_child.evicted:
|
756
832
|
stack.append(cur_child)
|
757
833
|
return ret_list
|
834
|
+
|
835
|
+
def release_aborted_request(self, rid: str):
|
836
|
+
if rid not in self.ongoing_prefetch:
|
837
|
+
return
|
838
|
+
|
839
|
+
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
|
840
|
+
if operation.host_indices is None:
|
841
|
+
return
|
842
|
+
|
843
|
+
completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
|
844
|
+
if self.tp_world_size > 1:
|
845
|
+
torch.distributed.barrier(group=self.tp_group)
|
846
|
+
last_host_node.release_host()
|
847
|
+
del self.ongoing_prefetch[rid]
|
848
|
+
self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
|
849
|
+
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
|
|
183
183
|
self.req_to_token_pool.free(req.req_pool_idx)
|
184
184
|
self.dec_lock_ref(req.last_node)
|
185
185
|
|
186
|
-
def cache_unfinished_req(self, req: Req):
|
186
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
187
187
|
"""Cache request when it is unfinished."""
|
188
188
|
if self.disable:
|
189
189
|
return
|