sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registry.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py

@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_storage_backend: Optional[str] = None,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -49,25 +50,36 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
+        self.enable_storage = hicache_storage_backend is not None
+        # todo: customizable storage prefetch threshold
+        self.prefetch_threshold = 256

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
+            storage_backend=hicache_storage_backend,
+            prefetch_threshold=self.prefetch_threshold,
         )

         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
+        # record the ongoing prefetch requests
+        self.ongoing_prefetch = {}
+        self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
+        self.write_through_threshold_storage = 3
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -108,13 +120,30 @@ class HiRadixCache(RadixCache):

         return len(host_indices)

+    def write_backup_storage(self, node: TreeNode):
+        operation_id = self.cache_controller.write_storage(
+            node.host_value, node.key, node.parent.get_last_hash_value()
+        )
+        self.ongoing_backup[operation_id] = node
+        node.protect_host()
+
     def inc_hit_count(self, node: TreeNode):
-        if
+        if self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-
-
-        node.hit_count
+
+        if not node.backuped:
+            if node.hit_count >= self.write_through_threshold:
+                # write to host if the node is not backuped
+                self.write_backup(node)
+        else:
+            if (
+                self.enable_storage
+                and (not node.backuped_storage)
+                and node.hit_count >= self.write_through_threshold_storage
+            ):
+                # if the node is backuped on host memory but not on storage
+                self.write_backup_storage(node)

     def writing_check(self, write_back=False):
         if write_back:
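Note: the rewritten inc_hit_count implements a two-tier promotion policy: hits first promote a node from device to host memory (write_through_threshold, 1 or 3 depending on the write policy), and once a node is host-backed, further hits promote it to the storage backend (write_through_threshold_storage, 3). A minimal standalone sketch of that policy follows; the dict-based node and the callback names are illustrative stand-ins, not the sglang classes.

# Illustrative sketch of the two-tier promotion in inc_hit_count above.
# The callbacks and the dict-based node are stand-ins for the real objects.
def on_hit(node, write_backup, write_backup_storage,
           host_threshold=3, storage_threshold=3, storage_enabled=True):
    node["hit_count"] += 1
    if not node["backuped"]:
        if node["hit_count"] >= host_threshold:
            write_backup(node)             # tier 1: device -> host
    elif storage_enabled and not node["backuped_storage"]:
        if node["hit_count"] >= storage_threshold:
            write_backup_storage(node)     # tier 2: host -> storage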
@@ -126,7 +155,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
-        if
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -221,6 +250,10 @@ class HiRadixCache(RadixCache):
             if not x.evicted:
                 continue

+            # node is protected from eviction as it has ongoing prefetch or backup to storage
+            if x.host_ref_counter > 0:
+                continue
+
             num_evicted += self.cache_controller.evict_host(x.host_value)

             for k, v in x.parent.children.items():
@@ -314,6 +347,94 @@ class HiRadixCache(RadixCache):
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
+        if self.enable_storage:
+            self.check_revoked_prefetch()
+            self.check_backup_progress()
+
+    def check_revoked_prefetch(self):
+        queue_size = torch.tensor(
+            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+        )
+        if self.tp_world_size > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            req_id = self.cache_controller.prefetch_revoke_queue.get()
+            if req_id in self.ongoing_prefetch:
+                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node.release_host()
+                self.cache_controller.mem_pool_host.free(host_indices)
+                del self.ongoing_prefetch[req_id]
+
+    def check_backup_progress(self):
+        queue_size = torch.tensor(
+            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+        )
+        if self.tp_world_size > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # backup is only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
+            del self.ongoing_backup[ack_id]
+
+    def check_prefetch_progress(self, req_id: str):
+        if req_id not in self.ongoing_prefetch:
+            # there is no ongoing prefetch for this request or it has been revoked
+            return
+
+        # todo: more policies for prefetch progress such as timeout
+        # the current policy is to prefetch with best effort and terminate when queuing is over
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
+        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+            operation
+        )
+        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            completed_tokens_tensor = torch.tensor(
+                min_completed_tokens, dtype=torch.int
+            )
+            torch.distributed.all_reduce(
+                completed_tokens_tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            min_completed_tokens = completed_tokens_tensor.item()
+        fetched_token_ids = token_ids[:min_completed_tokens]
+        written_indices = host_indices[:min_completed_tokens]
+        matched_length = self._insert_helper_host(
+            last_host_node,
+            fetched_token_ids,
+            written_indices,
+            hash_value[:min_completed_tokens],
+        )
+
+        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+        self.cache_controller.mem_pool_host.free(
+            host_indices[min_completed_tokens:completed_tokens]
+        )
+        last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]

     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
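Note: check_revoked_prefetch, check_backup_progress, and check_prefetch_progress all rely on the same consensus pattern: each TP rank all-reduces its local progress (a queue size or a completed-token count) with ReduceOp.MIN, so every rank applies the identical, most conservative update to its copy of the radix tree. A minimal sketch of that pattern, assuming a torch.distributed process group has already been initialized:

import torch
import torch.distributed as dist

def agree_on_progress(local_progress: int, group=None) -> int:
    """Return the minimum of local_progress across all ranks in the group."""
    progress = torch.tensor(local_progress, dtype=torch.int)
    if dist.get_world_size(group=group) > 1:
        dist.all_reduce(progress, op=dist.ReduceOp.MIN, group=group)
    return progress.item()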
@@ -348,6 +469,74 @@ class HiRadixCache(RadixCache):
             host_hit_length=host_hit_length,
         )

+    def prefetch_from_storage(
+        self,
+        req_id: str,
+        last_host_node: TreeNode,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        # align the number of fetching tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+            return
+
+        last_host_node.protect_host()
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            last_host_node.release_host()
+            # not enough host memory to prefetch
+            return
+        operation = self.cache_controller.prefetch(
+            req_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[req_id] = (
+            last_host_node,
+            new_input_tokens,
+            host_indices,
+            operation,
+        )
+
+    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
+        node.last_access_time = time.monotonic()
+        if len(key) == 0:
+            return 0
+
+        child_key = self.get_child_key_fn(key)
+
+        matched_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            prefix_len = self.key_match_fn(node.key, key)
+            key = key[prefix_len:]
+            host_value = host_value[prefix_len:]
+            hash_value = hash_value[prefix_len:]
+            matched_length += prefix_len
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = None
+            new_node.host_value = host_value
+            new_node.hash_value = hash_value
+            node.children[child_key] = new_node
+        return matched_length
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
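Note: prefetch_from_storage first rounds the request down to a whole number of pages and only issues a prefetch when the aligned length reaches prefetch_threshold (256 tokens by default). A worked example of the alignment arithmetic, with an illustrative page size:

page_size, prefetch_threshold = 64, 256        # illustrative page size; default threshold
new_input_tokens = list(range(300))            # a 300-token request
prefetch_length = len(new_input_tokens) - (len(new_input_tokens) % page_size)
assert prefetch_length == 256                  # 4 full pages; the trailing 44 tokens are dropped
assert prefetch_length >= prefetch_threshold   # long enough for a storage prefetch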
sglang/srt/mem_cache/memory_pool.py

@@ -520,8 +520,13 @@ class SWAKVPool(KVCache):
         self.layers_mapping[global_layer_id] = (swa_layer_id, True)
         self.full_to_swa_index_mapping: Optional[torch.Tensor] = None

+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
     def get_kv_size_bytes(self):
-
+        k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
+        k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes()
+        return k_size + k_size_swa, v_size + v_size_swa

     def get_contiguous_buf_infos(self):
         full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
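Note: SWAKVPool now reports its footprint as the sum of its full-attention and sliding-window sub-pools, cached once at construction as mem_usage. A toy calculation under that definition; the pool sizes are made up, and GB is assumed to be the gibibyte constant used in memory_pool.py:

GB = 1024 ** 3                          # assumption: the GB constant from memory_pool.py
full_k, full_v = 6 * GB, 6 * GB         # full-attention pool (k_size, v_size)
swa_k, swa_v = 2 * GB, 2 * GB           # sliding-window pool (k_size, v_size)
k_size, v_size = full_k + swa_k, full_v + swa_v
mem_usage = (k_size + v_size) / GB      # -> 16.0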
@@ -597,6 +602,16 @@ class SWAKVPool(KVCache):
             layer_id_override=layer_id_pool,
         )

+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+

 class AscendTokenToKVPool(MHATokenToKVPool):

sglang/srt/mem_cache/memory_pool_host.py

@@ -71,11 +71,12 @@ class HostKVCache(abc.ABC):
         requested_bytes = self.size * self.size_per_token
         # preserve at least 10GB for other usage
         ten_gb = 10 * (1024**3)
-
+        available_bytes = host_mem.available - ten_gb
+        if requested_bytes > available_bytes:
             raise ValueError(
                 f"Not enough host memory available. Requesting "
                 f"{requested_bytes / 1e9:.2f} GB but only have "
-                f"{
+                f"{available_bytes / 1e9:.2f} GB free. Please reduce the "
                 f"size of the hierarchical cache."
             )
         else:
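Note: the guard now computes the usable headroom explicitly: the requested host cache must fit within host_mem.available minus a fixed 10 GB reserve. A worked example with made-up numbers:

ten_gb = 10 * (1024 ** 3)
host_available = 64 * (1024 ** 3)           # e.g. psutil-reported available host memory
requested_bytes = 60 * (1024 ** 3)          # requested hierarchical cache size
available_bytes = host_available - ten_gb   # 54 GiB of usable headroom
assert requested_bytes > available_bytes    # so the real check raises ValueError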
@@ -98,6 +99,20 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

+    @abc.abstractmethod
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        """
+        Get a flat data page from the host memory pool.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        """
+        Set a flat data page to the host memory pool.
+        """
+        raise NotImplementedError()
+
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
@@ -111,6 +126,9 @@ class HostKVCache(abc.ABC):

     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
         if need_size > self.available_size():
             return None

@@ -226,6 +244,19 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )

+    # todo, page first memory layout
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
+            2,
+            self.layer_num,
+            self.page_size,
+            self.head_num,
+            self.head_dim,
+        )
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
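Note: get_flat_data_page and set_from_flat_data_page are exact inverses: a page is flattened out of the [2, layer_num, tokens, head_num, head_dim] host buffer for transfer, and reshaped back on the way in. A self-contained round-trip check on a dummy buffer with illustrative sizes:

import torch

layer_num, head_num, head_dim, page_size = 4, 8, 128, 16   # illustrative sizes
kv_buffer = torch.randn(2, layer_num, 64, head_num, head_dim)

index = 32                                                 # page-aligned token index
page = kv_buffer[:, :, index : index + page_size, :, :].flatten()
restored = page.reshape(2, layer_num, page_size, head_num, head_dim)
assert torch.equal(kv_buffer[:, :, index : index + page_size, :, :], restored)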
@@ -275,3 +306,14 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
+
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+            self.layer_num,
+            self.page_size,
+            1,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )
sglang/srt/mem_cache/radix_cache.py

@@ -55,8 +55,13 @@ class TreeNode:
         self.hit_count = 0
         # indicating the node is loading KV cache from host
         self.loading = False
+        # indicating the node is locked to protect from eviction
+        # incremented when the node is referenced by a storage operation
+        self.host_ref_counter = 0
         # store the host indices of KV cache
         self.host_value: Optional[torch.Tensor] = None
+        # store hash values of each page
+        self.hash_value: Optional[List[str]] = None

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -69,6 +74,27 @@ class TreeNode:
     def backuped(self):
         return self.host_value is not None

+    @property
+    def backuped_storage(self):
+        return self.hash_value is not None and len(self.hash_value) > 0
+
+    def protect_host(self):
+        """Protect the host value from eviction."""
+        self.host_ref_counter += 1
+
+    def release_host(self):
+        """Release the host value, allowing it to be evicted."""
+        if self.host_ref_counter > 0:
+            self.host_ref_counter -= 1
+        else:
+            raise RuntimeError("Host reference counter is already zero.")
+
+    def get_last_hash_value(self) -> Optional[str]:
+        """Returns the hash value of the last page in this node."""
+        if self.hash_value is None or len(self.hash_value) == 0:
+            return None
+        return self.hash_value[-1]
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time

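Note: protect_host/release_host give each node a simple reference count that pins its host pages while an asynchronous prefetch or storage backup is in flight; the host evictor skips any node with host_ref_counter > 0. A self-contained sketch of that discipline (HostRef is an illustrative stand-in, not the sglang TreeNode):

class HostRef:
    """Illustrative stand-in for TreeNode's host reference counting."""

    def __init__(self):
        self.host_ref_counter = 0

    def protect_host(self):
        self.host_ref_counter += 1

    def release_host(self):
        if self.host_ref_counter == 0:
            raise RuntimeError("Host reference counter is already zero.")
        self.host_ref_counter -= 1

    @property
    def evictable(self):
        return self.host_ref_counter == 0

node = HostRef()
node.protect_host()        # an async prefetch/backup starts referencing the node
assert not node.evictable  # evict_host() skips it while the operation is in flight
node.release_host()        # the operation completed or was revoked
assert node.evictable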