sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
|
20
20
|
MLATokenToKVPoolHost,
|
21
21
|
)
|
22
22
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
23
|
+
from sglang.srt.metrics.collector import StorageMetricsCollector
|
23
24
|
|
24
25
|
logger = logging.getLogger(__name__)
|
25
26
|
|
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
|
|
37
38
|
hicache_write_policy: str,
|
38
39
|
hicache_io_backend: str,
|
39
40
|
hicache_mem_layout: str,
|
41
|
+
enable_metrics: bool,
|
40
42
|
hicache_storage_backend: Optional[str] = None,
|
41
43
|
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
42
44
|
model_name: Optional[str] = None,
|
@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
|
|
73
75
|
self.tp_group = tp_cache_group
|
74
76
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
75
77
|
self.enable_storage = hicache_storage_backend is not None
|
78
|
+
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
79
|
+
|
76
80
|
# todo: customizable storage prefetch threshold and timeout
|
77
81
|
self.prefetch_threshold = 256
|
78
82
|
self.prefetch_timeout = 3 # seconds
|
@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
|
|
92
96
|
model_name=model_name,
|
93
97
|
storage_backend_extra_config=storage_backend_extra_config,
|
94
98
|
)
|
99
|
+
if self.enable_storage_metrics:
|
100
|
+
# TODO: support pp
|
101
|
+
labels = {
|
102
|
+
"storage_backend": hicache_storage_backend,
|
103
|
+
"tp_rank": self.cache_controller.tp_rank,
|
104
|
+
"dp_rank": self.cache_controller.dp_rank,
|
105
|
+
}
|
106
|
+
self.metrics_collector = StorageMetricsCollector(labels=labels)
|
95
107
|
|
96
108
|
# record the nodes with ongoing write through
|
97
109
|
self.ongoing_write_through = {}
|
@@ -102,10 +114,7 @@ class HiRadixCache(RadixCache):
|
|
102
114
|
self.ongoing_backup = {}
|
103
115
|
# todo: dynamically adjust the threshold
|
104
116
|
self.write_through_threshold = (
|
105
|
-
1 if hicache_write_policy == "write_through" else
|
106
|
-
)
|
107
|
-
self.write_through_threshold_storage = (
|
108
|
-
1 if hicache_write_policy == "write_through" else 3
|
117
|
+
1 if hicache_write_policy == "write_through" else 2
|
109
118
|
)
|
110
119
|
self.load_back_threshold = 10
|
111
120
|
super().__init__(
|
@@ -125,6 +134,28 @@ class HiRadixCache(RadixCache):
|
|
125
134
|
height += 1
|
126
135
|
return height
|
127
136
|
|
137
|
+
def clear_storage_backend(self) -> bool:
|
138
|
+
if self.enable_storage:
|
139
|
+
try:
|
140
|
+
# Check if the storage backend has a clear method (for nixl backends)
|
141
|
+
if hasattr(self.cache_controller.storage_backend, "clear"):
|
142
|
+
self.cache_controller.storage_backend.clear()
|
143
|
+
logger.info(
|
144
|
+
"Hierarchical cache storage backend cleared successfully!"
|
145
|
+
)
|
146
|
+
return True
|
147
|
+
else:
|
148
|
+
logger.warning(
|
149
|
+
f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
|
150
|
+
)
|
151
|
+
return False
|
152
|
+
except Exception as e:
|
153
|
+
logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
|
154
|
+
return False
|
155
|
+
else:
|
156
|
+
logger.warning("Hierarchical cache storage backend is not enabled.")
|
157
|
+
return False
|
158
|
+
|
128
159
|
def write_backup(self, node: TreeNode, write_back=False):
|
129
160
|
host_indices = self.cache_controller.write(
|
130
161
|
device_indices=node.value,
|
@@ -155,8 +186,9 @@ class HiRadixCache(RadixCache):
|
|
155
186
|
self.ongoing_backup[operation_id] = node
|
156
187
|
node.protect_host()
|
157
188
|
|
158
|
-
def
|
159
|
-
|
189
|
+
def _inc_hit_count(self, node: TreeNode, chunked=False):
|
190
|
+
# skip the hit count update for chunked requests
|
191
|
+
if self.cache_controller.write_policy == "write_back" or chunked:
|
160
192
|
return
|
161
193
|
node.hit_count += 1
|
162
194
|
|
@@ -164,51 +196,62 @@ class HiRadixCache(RadixCache):
|
|
164
196
|
if node.hit_count >= self.write_through_threshold:
|
165
197
|
# write to host if the node is not backuped
|
166
198
|
self.write_backup(node)
|
167
|
-
else:
|
168
|
-
if (
|
169
|
-
self.enable_storage
|
170
|
-
and (not node.backuped_storage)
|
171
|
-
and node.hit_count >= self.write_through_threshold_storage
|
172
|
-
):
|
173
|
-
# if the node is backuped on host memory but not on storage
|
174
|
-
self.write_backup_storage(node)
|
175
199
|
|
176
200
|
def writing_check(self, write_back=False):
|
177
201
|
if write_back:
|
178
202
|
# blocking till all write back complete
|
179
203
|
while len(self.ongoing_write_through) > 0:
|
180
|
-
|
181
|
-
|
204
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
205
|
+
finish_event.synchronize()
|
206
|
+
for ack_id in ack_list:
|
207
|
+
del self.ongoing_write_through[ack_id]
|
208
|
+
self.cache_controller.ack_write_queue.clear()
|
209
|
+
assert len(self.ongoing_write_through) == 0
|
182
210
|
return
|
183
|
-
|
184
|
-
|
185
|
-
)
|
211
|
+
|
212
|
+
# NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
|
213
|
+
if len(self.ongoing_write_through) == 0:
|
214
|
+
return
|
215
|
+
|
216
|
+
finish_count = 0
|
217
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
218
|
+
if not finish_event.query():
|
219
|
+
break
|
220
|
+
finish_count += 1
|
221
|
+
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
|
186
222
|
if self.tp_world_size > 1:
|
187
|
-
#
|
223
|
+
# synchronize TP workers to make the same update to radix cache
|
188
224
|
torch.distributed.all_reduce(
|
189
225
|
queue_size,
|
190
226
|
op=torch.distributed.ReduceOp.MIN,
|
191
227
|
group=self.tp_group,
|
192
228
|
)
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
229
|
+
|
230
|
+
finish_count = int(queue_size.item())
|
231
|
+
while finish_count > 0:
|
232
|
+
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
233
|
+
finish_event.synchronize()
|
234
|
+
for ack_id in ack_list:
|
235
|
+
backuped_node = self.ongoing_write_through.pop(ack_id)
|
236
|
+
self.dec_lock_ref(backuped_node)
|
237
|
+
if self.enable_storage:
|
238
|
+
self.write_backup_storage(backuped_node)
|
239
|
+
finish_count -= 1
|
197
240
|
|
198
241
|
def loading_check(self):
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
self.dec_lock_ref(end_node)
|
204
|
-
while end_node != start_node:
|
205
|
-
assert end_node.loading
|
206
|
-
end_node.loading = False
|
207
|
-
end_node = end_node.parent
|
208
|
-
# clear the reference
|
209
|
-
del self.ongoing_load_back[ack_id]
|
210
|
-
except Exception:
|
242
|
+
finish_count = 0
|
243
|
+
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
|
244
|
+
if not finish_event.query():
|
245
|
+
# the KV cache loading is still ongoing
|
211
246
|
break
|
247
|
+
finish_count += 1
|
248
|
+
# no need to sync across TP workers as batch forwarding is synced
|
249
|
+
for ack_id in ack_list:
|
250
|
+
end_node = self.ongoing_load_back.pop(ack_id)
|
251
|
+
self.dec_lock_ref(end_node)
|
252
|
+
|
253
|
+
# ACK until all events are processed
|
254
|
+
del self.cache_controller.ack_load_queue[:finish_count]
|
212
255
|
|
213
256
|
def evictable_size(self):
|
214
257
|
return self.evictable_size_
|
@@ -333,12 +376,11 @@ class HiRadixCache(RadixCache):
|
|
333
376
|
# no sufficient GPU memory to load back KV caches
|
334
377
|
return None
|
335
378
|
|
336
|
-
self.ongoing_load_back[last_hit_node.id] =
|
379
|
+
self.ongoing_load_back[last_hit_node.id] = last_hit_node
|
337
380
|
offset = 0
|
338
381
|
for node in nodes_to_load:
|
339
382
|
node.value = device_indices[offset : offset + len(node.host_value)]
|
340
383
|
offset += len(node.host_value)
|
341
|
-
node.loading = True
|
342
384
|
self.evictable_size_ += len(device_indices)
|
343
385
|
self.inc_lock_ref(last_hit_node)
|
344
386
|
|
@@ -367,66 +409,72 @@ class HiRadixCache(RadixCache):
|
|
367
409
|
last_node,
|
368
410
|
)
|
369
411
|
|
370
|
-
def ready_to_load_host_cache(self):
|
371
|
-
|
372
|
-
|
373
|
-
|
412
|
+
def ready_to_load_host_cache(self) -> int:
|
413
|
+
"""
|
414
|
+
Notify the cache controller to start the KV cache loading.
|
415
|
+
Return the consumer index for the schedule batch manager to track.
|
416
|
+
"""
|
417
|
+
return self.cache_controller.start_loading()
|
374
418
|
|
375
419
|
def check_hicache_events(self):
|
376
420
|
self.writing_check()
|
377
421
|
self.loading_check()
|
378
422
|
if self.enable_storage:
|
379
|
-
self.
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
queue_size = torch.tensor(
|
384
|
-
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
|
385
|
-
)
|
386
|
-
if self.tp_world_size > 1:
|
387
|
-
# synchrnoize TP workers to make the same update to hiradix cache
|
388
|
-
torch.distributed.all_reduce(
|
389
|
-
queue_size,
|
390
|
-
op=torch.distributed.ReduceOp.MIN,
|
391
|
-
group=self.tp_group,
|
423
|
+
self.drain_storage_control_queues()
|
424
|
+
if self.enable_storage_metrics:
|
425
|
+
self.metrics_collector.log_storage_metrics(
|
426
|
+
self.cache_controller.storage_backend.get_stats()
|
392
427
|
)
|
393
|
-
for _ in range(queue_size.item()):
|
394
|
-
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
395
|
-
if req_id in self.ongoing_prefetch:
|
396
|
-
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
397
|
-
last_host_node.release_host()
|
398
|
-
del self.ongoing_prefetch[req_id]
|
399
|
-
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
400
|
-
else:
|
401
|
-
# the revoked operation already got terminated
|
402
|
-
pass
|
403
428
|
|
404
|
-
def
|
405
|
-
|
406
|
-
|
429
|
+
def drain_storage_control_queues(self):
|
430
|
+
"""
|
431
|
+
Combine prefetch revoke, backup ack, and host mem release checks
|
432
|
+
to minimize TP synchronization and Python overhead.
|
433
|
+
"""
|
434
|
+
cc = self.cache_controller
|
435
|
+
|
436
|
+
qsizes = torch.tensor(
|
437
|
+
[
|
438
|
+
cc.prefetch_revoke_queue.qsize(),
|
439
|
+
cc.ack_backup_queue.qsize(),
|
440
|
+
cc.host_mem_release_queue.qsize(),
|
441
|
+
],
|
442
|
+
dtype=torch.int,
|
407
443
|
)
|
408
444
|
if self.tp_world_size > 1:
|
409
|
-
# synchrnoize TP workers to make the same update to hiradix cache
|
410
445
|
torch.distributed.all_reduce(
|
411
|
-
|
412
|
-
op=torch.distributed.ReduceOp.MIN,
|
413
|
-
group=self.tp_group,
|
446
|
+
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
414
447
|
)
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
448
|
+
|
449
|
+
n_revoke, n_backup, n_release = map(int, qsizes.tolist())
|
450
|
+
|
451
|
+
# process prefetch revokes
|
452
|
+
for _ in range(n_revoke):
|
453
|
+
req_id = cc.prefetch_revoke_queue.get()
|
454
|
+
info = self.ongoing_prefetch.pop(req_id, None)
|
455
|
+
if info is not None:
|
456
|
+
last_host_node, token_ids, _, _ = info
|
457
|
+
last_host_node.release_host()
|
458
|
+
cc.prefetch_tokens_occupied -= len(token_ids)
|
459
|
+
# else: the revoked operation already got terminated, nothing to do
|
460
|
+
|
461
|
+
# process backup acks
|
462
|
+
for _ in range(n_backup):
|
463
|
+
operation = cc.ack_backup_queue.get()
|
464
|
+
ack_id = operation.id
|
465
|
+
entry = self.ongoing_backup.pop(ack_id, None)
|
466
|
+
if entry is not None:
|
467
|
+
entry.release_host()
|
468
|
+
if self.enable_storage_metrics:
|
469
|
+
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
|
470
|
+
|
471
|
+
# release host memory
|
472
|
+
host_indices_list = []
|
473
|
+
for _ in range(n_release):
|
474
|
+
host_indices_list.append(cc.host_mem_release_queue.get())
|
475
|
+
if host_indices_list:
|
476
|
+
host_indices = torch.cat(host_indices_list, dim=0)
|
477
|
+
cc.mem_pool_host.free(host_indices)
|
430
478
|
|
431
479
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
432
480
|
can_terminate = True
|
@@ -451,15 +499,22 @@ class HiRadixCache(RadixCache):
|
|
451
499
|
# unknown prefetch stop policy, just return True
|
452
500
|
return True
|
453
501
|
|
502
|
+
operation_terminated = operation.is_terminated()
|
454
503
|
if self.tp_world_size > 1:
|
455
|
-
|
504
|
+
states = torch.tensor(
|
505
|
+
[1 - int(can_terminate), int(operation_terminated)],
|
506
|
+
dtype=torch.int,
|
507
|
+
)
|
456
508
|
torch.distributed.all_reduce(
|
457
|
-
|
458
|
-
op=torch.distributed.ReduceOp.
|
509
|
+
states,
|
510
|
+
op=torch.distributed.ReduceOp.MAX,
|
459
511
|
group=self.tp_group,
|
460
512
|
)
|
461
|
-
can_terminate =
|
462
|
-
|
513
|
+
can_terminate = states[0].item() == 0
|
514
|
+
operation_terminated = states[1].item() == 1
|
515
|
+
# the operation should be terminated if it is already terminated on any TP worker
|
516
|
+
# or it meets the termination condition on all TP workers
|
517
|
+
can_terminate = can_terminate or operation_terminated
|
463
518
|
return can_terminate
|
464
519
|
|
465
520
|
def check_prefetch_progress(self, req_id: str) -> bool:
|
@@ -486,7 +541,7 @@ class HiRadixCache(RadixCache):
|
|
486
541
|
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
487
542
|
|
488
543
|
min_completed_tokens = completed_tokens
|
489
|
-
if self.tp_world_size > 1
|
544
|
+
if self.tp_world_size > 1:
|
490
545
|
# synchrnoize TP workers to make the same update to hiradix cache
|
491
546
|
completed_tokens_tensor = torch.tensor(
|
492
547
|
min_completed_tokens, dtype=torch.int
|
@@ -509,13 +564,18 @@ class HiRadixCache(RadixCache):
|
|
509
564
|
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
510
565
|
|
511
566
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
512
|
-
self.cache_controller.
|
567
|
+
self.cache_controller.append_host_mem_release(
|
513
568
|
host_indices[min_completed_tokens:completed_tokens]
|
514
569
|
)
|
515
570
|
last_host_node.release_host()
|
516
571
|
del self.ongoing_prefetch[req_id]
|
517
572
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
518
573
|
|
574
|
+
if self.enable_storage_metrics:
|
575
|
+
self.metrics_collector.log_prefetched_tokens(
|
576
|
+
min_completed_tokens - matched_length
|
577
|
+
)
|
578
|
+
|
519
579
|
return True
|
520
580
|
|
521
581
|
def match_prefix(self, key: List[int], **kwargs):
|
@@ -565,7 +625,11 @@ class HiRadixCache(RadixCache):
|
|
565
625
|
len(new_input_tokens) % self.page_size
|
566
626
|
)
|
567
627
|
new_input_tokens = new_input_tokens[:prefetch_length]
|
568
|
-
if
|
628
|
+
if (
|
629
|
+
not self.enable_storage
|
630
|
+
or prefetch_length < self.prefetch_threshold
|
631
|
+
or self.cache_controller.prefetch_rate_limited()
|
632
|
+
):
|
569
633
|
return
|
570
634
|
|
571
635
|
last_host_node.protect_host()
|
@@ -573,6 +637,10 @@ class HiRadixCache(RadixCache):
|
|
573
637
|
if host_indices is None:
|
574
638
|
self.evict_host(prefetch_length)
|
575
639
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
640
|
+
if host_indices is None:
|
641
|
+
last_host_node.release_host()
|
642
|
+
# no sufficient host memory for prefetch
|
643
|
+
return
|
576
644
|
operation = self.cache_controller.prefetch(
|
577
645
|
req_id, host_indices, new_input_tokens, last_hash
|
578
646
|
)
|
@@ -651,7 +719,6 @@ class HiRadixCache(RadixCache):
|
|
651
719
|
new_node.parent = child.parent
|
652
720
|
new_node.lock_ref = child.lock_ref
|
653
721
|
new_node.key = child.key[:split_len]
|
654
|
-
new_node.loading = child.loading
|
655
722
|
new_node.hit_count = child.hit_count
|
656
723
|
|
657
724
|
# split value and host value if exists
|
@@ -672,11 +739,11 @@ class HiRadixCache(RadixCache):
|
|
672
739
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
673
740
|
return new_node
|
674
741
|
|
675
|
-
def
|
676
|
-
node.last_access_time = time.monotonic()
|
742
|
+
def insert(self, key: List, value, chunked=False):
|
677
743
|
if len(key) == 0:
|
678
744
|
return 0
|
679
745
|
|
746
|
+
node = self.root_node
|
680
747
|
child_key = self.get_child_key_fn(key)
|
681
748
|
total_prefix_length = 0
|
682
749
|
|
@@ -693,7 +760,7 @@ class HiRadixCache(RadixCache):
|
|
693
760
|
self.token_to_kv_pool_host.update_synced(node.host_value)
|
694
761
|
self.evictable_size_ += len(node.value)
|
695
762
|
else:
|
696
|
-
self.
|
763
|
+
self._inc_hit_count(node, chunked)
|
697
764
|
total_prefix_length += prefix_len
|
698
765
|
else:
|
699
766
|
# partial match, split the node
|
@@ -703,7 +770,7 @@ class HiRadixCache(RadixCache):
|
|
703
770
|
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
704
771
|
self.evictable_size_ += len(new_node.value)
|
705
772
|
else:
|
706
|
-
self.
|
773
|
+
self._inc_hit_count(new_node, chunked)
|
707
774
|
total_prefix_length += prefix_len
|
708
775
|
node = new_node
|
709
776
|
|
@@ -737,7 +804,7 @@ class HiRadixCache(RadixCache):
|
|
737
804
|
last_hash = new_node.hash_value[-1]
|
738
805
|
|
739
806
|
if self.cache_controller.write_policy != "write_back":
|
740
|
-
self.
|
807
|
+
self._inc_hit_count(new_node, chunked)
|
741
808
|
return total_prefix_length
|
742
809
|
|
743
810
|
def _collect_leaves_device(self):
|
@@ -764,3 +831,19 @@ class HiRadixCache(RadixCache):
|
|
764
831
|
if not cur_child.evicted:
|
765
832
|
stack.append(cur_child)
|
766
833
|
return ret_list
|
834
|
+
|
835
|
+
def release_aborted_request(self, rid: str):
|
836
|
+
if rid not in self.ongoing_prefetch:
|
837
|
+
return
|
838
|
+
|
839
|
+
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
|
840
|
+
if operation.host_indices is None:
|
841
|
+
return
|
842
|
+
|
843
|
+
completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
|
844
|
+
if self.tp_world_size > 1:
|
845
|
+
torch.distributed.barrier(group=self.tp_group)
|
846
|
+
last_host_node.release_host()
|
847
|
+
del self.ongoing_prefetch[rid]
|
848
|
+
self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
|
849
|
+
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
|
|
183
183
|
self.req_to_token_pool.free(req.req_pool_idx)
|
184
184
|
self.dec_lock_ref(req.last_node)
|
185
185
|
|
186
|
-
def cache_unfinished_req(self, req: Req):
|
186
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
187
187
|
"""Cache request when it is unfinished."""
|
188
188
|
if self.disable:
|
189
189
|
return
|