sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -574,7 +574,7 @@ class TokenizerManager:
                     "The server is not configured to enable custom logit processor. "
                     "Please set `--enable-custom-logits-processor` to enable this feature."
                 )
-        if self.server_args.
+        if self.server_args.enable_lora and obj.lora_path:
             self._validate_lora_adapters(obj)

     def _validate_input_ids_in_vocab(
@@ -604,7 +604,7 @@ class TokenizerManager:
             sampling_kwargs = obj.sampling_params
         sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
-        sampling_params.verify()
+        sampling_params.verify(self.model_config.vocab_size)

         # Build return object
         if isinstance(obj, GenerateReqInput):
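verify() now receives the model's vocabulary size, so sampling parameters that reference concrete token ids can be range-checked at request time instead of failing inside the model. A minimal sketch of the kind of check this enables; the `stop_token_ids` field is an illustrative stand-in, since the actual body of `SamplingParams.verify()` is not shown in this diff:

```python
from typing import List, Optional


class SamplingParamsSketch:
    """Illustrative stand-in, not sglang's SamplingParams."""

    def __init__(self, stop_token_ids: Optional[List[int]] = None):
        self.stop_token_ids = stop_token_ids or []

    def verify(self, vocab_size: int) -> None:
        # Reject token ids outside [0, vocab_size).
        for tid in self.stop_token_ids:
            if not 0 <= tid < vocab_size:
                raise ValueError(
                    f"token id {tid} is out of range for vocab_size={vocab_size}"
                )


SamplingParamsSketch(stop_token_ids=[2, 128001]).verify(vocab_size=128256)  # ok
```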
@@ -1037,6 +1037,10 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
+        if not self.server_args.enable_lora:
+            raise ValueError(
+                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+            )

         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
@@ -1060,6 +1064,10 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
+        if not self.server_args.enable_lora:
+            raise ValueError(
+                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+            )

         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
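Both the load and unload handlers now fail fast unless the server was launched with `--enable-lora`, instead of accepting adapter requests that can never be served. A hedged client-side sketch; the endpoint path and payload shape are assumptions for illustration, and only the flag and the error message come from this diff:

```python
import requests

# Assumed endpoint path and payload; not confirmed by this diff.
resp = requests.post(
    "http://localhost:30000/load_lora_adapter",
    json={"lora_name": "my-adapter", "lora_path": "/path/to/adapter"},
)
# Against a server started without --enable-lora, this is expected to fail with:
# "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
print(resp.status_code, resp.text)
```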
@@ -1359,7 +1367,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
-            self.last_receive_tstamp = time.
+            self.last_receive_tstamp = time.perf_counter()

     def _handle_batch_output(
         self,
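The receive timestamp switches to `time.perf_counter()`, a monotonic clock: intervals computed from it stay correct even if the system wall clock jumps (NTP sync, manual adjustment). A minimal illustration:

```python
import time

t0 = time.perf_counter()
time.sleep(0.01)
# Differences between perf_counter() readings measure real elapsed time,
# unaffected by wall-clock changes; time.time() offers no such guarantee.
print(f"elapsed: {time.perf_counter() - t0:.4f}s")
```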
sglang/srt/managers/tp_worker.py
CHANGED
@@ -174,6 +174,20 @@ class TpModelWorker:
             self.model_runner.token_to_kv_pool.size,
         )

+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.model_runner.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.model_runner.is_hybrid is not None
+
+    def get_tokens_per_layer_info(self):
+        return (
+            self.model_runner.full_max_total_num_tokens,
+            self.model_runner.swa_max_total_num_tokens,
+        )
+
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)

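These accessors expose the model runner's hybrid-attention state: a model is hybrid when it mixes full-attention and sliding-window (SWA) layers, each with its own KV token budget. A hedged sketch of how a caller might consume them; the `worker` object is a stand-in:

```python
# Illustrative only; `worker` stands in for a TpModelWorker instance.
def report_kv_budget(worker) -> str:
    if worker.is_hybrid:
        full_tokens, swa_tokens = worker.get_tokens_per_layer_info()
        return (
            f"hybrid: {full_tokens} full-attention tokens, "
            f"{swa_tokens} SWA tokens (window={worker.sliding_window_size})"
        )
    return "dense: single KV token budget"
```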
@@ -102,6 +102,17 @@ class TpModelWorkerClient:
     def get_worker_info(self):
         return self.worker.get_worker_info()

+    def get_tokens_per_layer_info(self):
+        return self.worker.get_tokens_per_layer_info()
+
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.worker.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.worker.is_hybrid
+
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()

@@ -57,11 +57,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
     def debug_print(self) -> str:
         return ""

-    def log_usage(self, evictable_size: int = 0):
-        num_used = self.size - (self.available_size() + evictable_size)
-        msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
-        return msg, num_used
-
     def available_size(self):
         return len(self.free_pages) * self.page_size

@@ -190,7 +185,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping

     def available_size(self):
-
+        raise NotImplementedError()

     def full_available_size(self):
         return self.full_attn_allocator.available_size()
@@ -214,16 +209,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         return msg

-    def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
-        used_full = self.size_full - (self.full_available_size() + full_evictable_size)
-        used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
-        msg = (
-            f"#token: full={used_full}, swa={used_swa}, "
-            f"token usage: full={used_full / self.size_full:.2f}, "
-            f"swa={used_swa / self.size_swa:.2f}, "
-        )
-        return msg, used_full
-
     def get_kvcache(self):
         return self._kvcache

@@ -541,6 +526,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.is_not_in_free_group = True
         self.free_group = []

+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+

 def alloc_extend_kernel_ascend(
     prefix_lens,
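The allocator simply forwards CPU-copy requests to its underlying KV cache, keeping the facade thin. A self-contained mock of that delegation pattern, assuming (not shown in this diff) that `get_cpu_copy` snapshots the selected token slots to host memory and `load_cpu_copy` restores them:

```python
import torch


class MockKVCache:
    """Stand-in for the real KV cache; real code moves data device <-> host."""

    def __init__(self, num_slots: int, dim: int):
        self.buf = torch.zeros(num_slots, dim)

    def get_cpu_copy(self, indices: torch.Tensor) -> torch.Tensor:
        return self.buf[indices].clone()

    def load_cpu_copy(self, kv_cache_cpu: torch.Tensor, indices: torch.Tensor):
        self.buf[indices] = kv_cache_cpu


class MockAllocator:
    def __init__(self, kvcache: MockKVCache):
        self._kvcache = kvcache

    # Same delegation shape as the diff above.
    def get_cpu_copy(self, indices):
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)


alloc = MockAllocator(MockKVCache(num_slots=32, dim=4))
idx = torch.tensor([0, 1, 2])
snapshot = alloc.get_cpu_copy(idx)  # save KV entries for selected slots
alloc.load_cpu_copy(snapshot, idx)  # restore them later
```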
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple

 import torch

@@ -56,15 +56,27 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         pass

     def evictable_size(self):
         return 0

+    def full_evictable_size(self):
+        return 0
+
+    def swa_evictable_size(self):
+        return 0
+
     def protected_size(self):
         return 0

+    def full_protected_size(self):
+        return 0
+
+    def swa_protected_size(self):
+        return 0
+
     def total_size(self):
         raise NotImplementedError()

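The `full_*`/`swa_*` variants split cache accounting between the full-attention and sliding-window KV pools of hybrid models; the base class returns 0 so existing non-hybrid caches need no changes (the new `swa_radix_cache.py` added in this release presumably overrides them). A small sketch of how a scheduler might use the split counters; the policy shown is an illustrative assumption:

```python
# Illustrative only; `cache` stands in for any BasePrefixCache subclass.
def can_free(cache, needed_full: int, needed_swa: int) -> bool:
    # Non-hybrid caches inherit the base-class defaults and report 0 for the
    # SWA counters, so this check degenerates gracefully for them.
    return (
        cache.full_evictable_size() >= needed_full
        and cache.swa_evictable_size() >= needed_swa
    )
```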
@@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache):
     def inc_lock_ref(self, node: Any):
         return 0

-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0

     def pretty_print(self):
@@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache):
         super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)

-    def
+    def evict_swa(
         self,
         req: Req,
         prelen: int,
@@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache):
         ]
         self.token_to_kv_pool_allocator.free_swa(free_slots)
         req.evicted_seqlen_local = new_evicted_seqlen_local
+
+    def evict(self, num_tokens: int):
+        pass
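`evict_swa` frees KV slots that have slid out of the attention window as a request grows, which is why the generic `evict(num_tokens)` can become a no-op here: reclamation happens per request, not globally. A rough illustration of the window arithmetic, with hypothetical numbers (the elided body of `evict_swa` also has to respect page alignment):

```python
# Hypothetical numbers; illustrates which slots fall out of a sliding window.
window = 4096
seq_len = 10000
evicted_so_far = 5000

# SWA layers no longer attend to tokens before seq_len - window.
new_bound = max(seq_len - window, 0)         # 5904
freeable = range(evicted_so_far, new_bound)  # slots [5000, 5904)
print(f"can free {len(freeable)} slots")
```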
@@ -0,0 +1,152 @@
+import hashlib
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
+    hasher = hashlib.sha256()
+
+    if prior_hash:
+        hasher.update(bytes.fromhex(prior_hash))
+
+    for t in token_ids:
+        hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+
+    return hasher.hexdigest()
+
+
+class HiCacheStorage(ABC):
+    """
+    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
+    It abstracts the underlying storage mechanism, allowing different implementations to be used.
+    """
+
+    # todo, translate tensor object access for different TP ranks
+    # potentially pass model and TP configs into storage backend
+    # todo, the page size of storage backend does not have to be the same as the same as host memory pool
+
+    @abstractmethod
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        """
+        Retrieve the value associated with the given key.
+        Returns None if the key does not exist.
+        """
+        pass
+
+    @abstractmethod
+    def batch_get(
+        self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
+    ) -> List[torch.Tensor | None]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of tensors or None for each key.
+        """
+        pass
+
+    @abstractmethod
+    def set(self, key, value) -> bool:
+        """
+        Store the value associated with the given key.
+        Returns True if the operation was successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        """
+        Store multiple key-value pairs.
+        Returns True if all operations were successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def exists(self, key: str) -> bool:
+        """
+        Check if the key exists in the storage.
+        Returns True if the key exists, False otherwise.
+        """
+        pass
+
+
+class HiCacheFile(HiCacheStorage):
+
+    def __init__(self, file_path: str = "/tmp/hicache"):
+        self.file_path = file_path
+        if not os.path.exists(self.file_path):
+            os.makedirs(self.file_path)
+            logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
+
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            # todo: fixing the target_location logic to enable in-place loading
+            loaded_tensor = torch.load(tensor_path)
+            if isinstance(loaded_tensor, torch.Tensor):
+                return loaded_tensor
+            else:
+                logger.error(f"Loaded data for key {key} is not a tensor.")
+                return None
+        except FileNotFoundError:
+            return None
+
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor]] = None,
+    ) -> List[torch.Tensor | None]:
+        return [
+            self.get(key, target_location)
+            for key, target_location in zip(
+                keys, target_locations or [None] * len(keys)
+            )
+        ]
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        if self.exists(key):
+            logger.debug(f"Key {key} already exists. Skipped.")
+            return True
+        try:
+            torch.save(value, tensor_path)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to save tensor {key}: {e}")
+            return False
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        for key, value in zip(keys, values):
+            if not self.set(key, value):
+                return False
+        return True
+
+    def exists(self, key: str) -> bool:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        return os.path.exists(tensor_path)
+
+    def delete(self, key: str) -> None:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            os.remove(tensor_path)
+        except FileNotFoundError:
+            logger.warning(f"Key {key} does not exist. Cannot delete.")
+            return
+
+    def clear(self) -> None:
+        try:
+            for filename in os.listdir(self.file_path):
+                file_path = os.path.join(self.file_path, filename)
+                if os.path.isfile(file_path):
+                    os.remove(file_path)
+            logger.info("Cleared all entries in HiCacheFile storage.")
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheFile storage: {e}")
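`get_hash_str` chains page hashes: each key is the SHA-256 over the previous page's digest bytes plus the page's little-endian token ids, so a page's storage key commits to the entire token prefix, and two sequences share a key only if they share the whole prefix. A short sketch pairing it with `HiCacheFile`; the token ids, tensor shape, and demo directory are hypothetical:

```python
import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

page1 = [101, 2054, 2003]                # hypothetical token ids
page2 = [1996, 3437, 102]
h1 = get_hash_str(page1)                 # key for page 1
h2 = get_hash_str(page2, prior_hash=h1)  # key for pages 1+2, via chaining

storage = HiCacheFile("/tmp/hicache_demo")  # demo directory
storage.set(h1, torch.randn(3, 8))          # stand-in KV tensor
assert storage.exists(h1)
restored = storage.get(h1)
storage.clear()
```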
@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_storage_backend: Optional[str] = None,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -49,6 +50,9 @@
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
+        self.enable_storage = hicache_storage_backend is not None
+        # todo: customizable storage prefetch threshold
+        self.prefetch_threshold = 256

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -58,16 +62,22 @@
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
+            storage_backend=hicache_storage_backend,
+            prefetch_threshold=self.prefetch_threshold,
         )

         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
+        # record the ongoing prefetch requests
+        self.ongoing_prefetch = {}
+        self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
+        self.write_through_threshold_storage = 3
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
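Taken together, these hunks thread an optional storage backend from HiRadixCache into HiCacheController, turning the two-tier device/host hierarchy into a three-tier one (device → host → storage). A hypothetical construction sketch; only the parameter names visible in this diff are grounded, while the values and the `hicache_ratio` argument are assumptions:

```python
# Hypothetical wiring based on the signature visible in this diff.
cache = HiRadixCache(
    req_to_token_pool,                    # as passed through to RadixCache
    token_to_kv_pool_allocator,
    tp_cache_group,
    page_size,
    hicache_ratio=2.0,                    # assumed parameter / value
    hicache_size=0,                       # hypothetical value
    hicache_write_policy="write_through",
    hicache_io_backend="kernel",          # hypothetical value
    hicache_storage_backend="file",       # None keeps the storage tier off
)
```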
@@ -108,13 +118,30 @@

         return len(host_indices)

+    def write_backup_storage(self, node: TreeNode):
+        operation_id = self.cache_controller.write_storage(
+            node.host_value, node.key, node.parent.get_last_hash_value()
+        )
+        self.ongoing_backup[operation_id] = node
+        node.protect_host()
+
     def inc_hit_count(self, node: TreeNode):
-        if
+        if self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-
-
-        node.hit_count
+
+        if not node.backuped:
+            if node.hit_count >= self.write_through_threshold:
+                # write to host if the node is not backuped
+                self.write_backup(node)
+        else:
+            if (
+                self.enable_storage
+                and (not node.backuped_storage)
+                and node.hit_count >= self.write_through_threshold_storage
+            ):
+                # if the node is backuped on host memory but not on storage
+                self.write_backup_storage(node)

     def writing_check(self, write_back=False):
         if write_back:
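Hit counting now drives a two-stage promotion: a node is first backed up from device to host once `hit_count` reaches `write_through_threshold`, and, when storage is enabled, from host to the storage backend after `write_through_threshold_storage` further hits. A pure-Python simulation of that policy; the thresholds mirror the diff, while the counter reset on promotion is an assumption:

```python
WRITE_THROUGH_THRESHOLD = 3          # device -> host (1 only under "write_through")
WRITE_THROUGH_THRESHOLD_STORAGE = 3  # host -> storage


class Node:
    def __init__(self):
        self.hit_count = 0
        self.backuped = False           # resident in host memory
        self.backuped_storage = False   # resident in the storage backend


def on_hit(node: Node, storage_enabled: bool = True) -> None:
    node.hit_count += 1
    if not node.backuped:
        if node.hit_count >= WRITE_THROUGH_THRESHOLD:
            node.backuped = True        # stand-in for write_backup(node)
            node.hit_count = 0          # assumption: counter resets here
    elif storage_enabled and not node.backuped_storage:
        if node.hit_count >= WRITE_THROUGH_THRESHOLD_STORAGE:
            node.backuped_storage = True  # stand-in for write_backup_storage(node)


n = Node()
for _ in range(6):
    on_hit(n)
print(n.backuped, n.backuped_storage)  # True True after enough hits
```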
@@ -221,6 +248,10 @@
             if not x.evicted:
                 continue

+            # node is protected from eviction as it has ongoing prefetch or backup to storage
+            if x.host_ref_counter > 0:
+                continue
+
             num_evicted += self.cache_controller.evict_host(x.host_value)

             for k, v in x.parent.children.items():
@@ -314,6 +345,85 @@
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
+        if self.enable_storage:
+            self.check_revoked_prefetch()
+            self.check_backup_progress()
+
+    def check_revoked_prefetch(self):
+        queue_size = torch.tensor(
+            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            req_id = self.cache_controller.prefetch_revoke_queue.get()
+            if req_id in self.ongoing_prefetch:
+                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node.release_host()
+                self.cache_controller.mem_pool_host.free(host_indices)
+                del self.ongoing_prefetch[req_id]
+
+    def check_backup_progress(self):
+        queue_size = torch.tensor(
+            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
+            self.ongoing_backup[ack_id].hash_value = hash_value
+            self.ongoing_backup[ack_id].release_host()
+            del self.ongoing_backup[ack_id]
+
+    def check_prefetch_progress(self, req_id: str):
+        if req_id not in self.ongoing_prefetch:
+            # there is no ongoing prefetch for this request or it has been revoked
+            return
+
+        # todo: more policies for prefetch progress such as timeout
+        # the current policy is to prefetch with best effort and terminate when queuing is over
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
+        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+            operation
+        )
+        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+
+        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                min_completed_tokens,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        min_completed_tokens = min_completed_tokens.item()
+        fetched_token_ids = token_ids[:min_completed_tokens]
+        written_indices = host_indices[:min_completed_tokens]
+        matched_length = self._insert_helper_host(
+            last_host_node,
+            fetched_token_ids,
+            written_indices,
+            hash_value[:min_completed_tokens],
+        )
+
+        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+        self.cache_controller.mem_pool_host.free(
+            host_indices[min_completed_tokens:completed_tokens]
+        )
+        last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]

     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
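The repeated `all_reduce(..., op=MIN)` pattern is what keeps TP ranks consistent: each rank prefetches (or drains its queues) independently, so they may reach different counts, and committing the minimum guarantees every rank applies the same update to its radix tree. The idea in plain Python, with hypothetical per-rank counts:

```python
# What all_reduce(ReduceOp.MIN) computes across TP ranks, in miniature.
completed_per_rank = [512, 384, 512, 448]  # tokens fetched by ranks 0..3
committed = min(completed_per_rank)        # 384: safe to commit on every rank
print(committed)
# Tokens a rank fetched beyond 384 are freed rather than inserted
# (host_indices[min_completed_tokens:completed_tokens] in the diff).
```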
@@ -348,6 +458,71 @@
             host_hit_length=host_hit_length,
         )

+    def prefetch_from_storage(
+        self,
+        req_id: str,
+        last_host_node: TreeNode,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+            return
+
+        last_host_node.protect_host()
+        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        if host_indices is None:
+            self.evict_host(len(new_input_tokens))
+            host_indices = self.cache_controller.mem_pool_host.alloc(
+                len(new_input_tokens)
+            )
+        if host_indices is None:
+            last_host_node.release_host()
+            # no sufficient host memory to prefetch
+            return
+        operation = self.cache_controller.prefetch(
+            req_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[req_id] = (
+            last_host_node,
+            new_input_tokens,
+            host_indices,
+            operation,
+        )
+
+    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
+        node.last_access_time = time.monotonic()
+        if len(key) == 0:
+            return 0
+
+        child_key = self.get_child_key_fn(key)
+
+        matched_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            prefix_len = self.key_match_fn(node.key, key)
+            key = key[prefix_len:]
+            host_value = host_value[prefix_len:]
+            hash_value = hash_value[prefix_len:]
+            matched_length += prefix_len
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = None
+            new_node.host_value = host_value
+            new_node.hash_value = hash_value
+            node.children[child_key] = new_node
+        return matched_length
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
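At the end of `check_prefetch_progress`, host indices are freed in two slices: the prefix that turned out to already exist in the host radix tree (`matched_length`) and the over-fetch beyond the TP-agreed minimum. A worked example with hypothetical numbers:

```python
# Hypothetical numbers: this rank fetched 500 tokens, the TP-wide minimum was
# 450, and the first 100 fetched tokens were already in the host radix tree.
completed_tokens = 500
min_completed_tokens = 450
matched_length = 100

free_dup = range(0, matched_length)                        # [0, 100): duplicates
free_over = range(min_completed_tokens, completed_tokens)  # [450, 500): over-fetch
kept = range(matched_length, min_completed_tokens)         # [100, 450): newly inserted
print(len(free_dup), len(free_over), len(kept))            # 100 50 350
```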
@@ -520,8 +520,13 @@ class SWAKVPool(KVCache):
         self.layers_mapping[global_layer_id] = (swa_layer_id, True)
         self.full_to_swa_index_mapping: Optional[torch.Tensor] = None

+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
     def get_kv_size_bytes(self):
-
+        k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
+        k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes()
+        return k_size + k_size_swa, v_size + v_size_swa

     def get_contiguous_buf_infos(self):
         full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
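SWAKVPool now reports its footprint as the sum of its two sub-pools. A trivial worked example of the arithmetic; the sizes and the GB constant are stand-ins (sglang defines its own GB constant elsewhere):

```python
GB = 1 << 30  # stand-in constant
full_k, full_v = 6 * GB, 6 * GB   # hypothetical full-attention pool sizes
swa_k, swa_v = 2 * GB, 2 * GB     # hypothetical SWA pool sizes

k_size, v_size = full_k + swa_k, full_v + swa_v
mem_usage = (k_size + v_size) / GB
print(f"{mem_usage:.1f} GB")      # 16.0 GB
```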
@@ -597,6 +602,16 @@
             layer_id_override=layer_id_pool,
         )

+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+

 class AscendTokenToKVPool(MHATokenToKVPool):
