sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/allocator.py

@@ -43,12 +43,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
         self.size = size
         self.page_size = page_size
         self.dtype = dtype
         self.device = device
         self._kvcache = kvcache
+        self.need_sort = need_sort

         self.free_pages = None
         self.release_pages = None
@@ -79,6 +81,9 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def estimated_num_new_pages(self, bs, extend_num_tokens):
+        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
+
     def merge_and_sort_free(self):
         if len(self.release_pages) > 0:
             self.free_pages = torch.cat((self.free_pages, self.release_pages))
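For context, the helper added above is a coarse upper bound, bs * ceil(extend_num_tokens / page_size), used only to decide whether merging the release list is worth doing before an allocation. A standalone sketch with made-up numbers (illustration only, not the sglang class):

# Standalone sketch of the page-count upper bound added above.
def estimated_num_new_pages(bs: int, extend_num_tokens: int, page_size: int) -> int:
    # ceil(extend_num_tokens / page_size), multiplied by the batch size
    return bs * ((extend_num_tokens + page_size - 1) // page_size)

# 130 extend tokens with 64-token pages round up to 3 pages; 4 requests -> 12 pages.
assert estimated_num_new_pages(bs=4, extend_num_tokens=130, page_size=64) == 12
# A decode-style call (one token per request) collapses to the batch size.
assert estimated_num_new_pages(bs=8, extend_num_tokens=1, page_size=64) == 8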
@@ -117,8 +122,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
 class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """An allocator managing the indices to kv cache data."""

-    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
-        super().__init__(size, 1, dtype, device, kvcache)
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         self.clear()

     def clear(self):
@@ -135,7 +147,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         return len(self.free_pages) + len(self.release_pages)

     def alloc(self, need_size: int):
-        if need_size > len(self.free_pages):
+        if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
         if need_size > len(self.free_pages):
             return None
@@ -149,7 +161,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return

         if self.is_not_in_free_group:
-            self.release_pages = torch.cat((self.release_pages, free_index))
+            if self.need_sort:
+                self.release_pages = torch.cat((self.release_pages, free_index))
+            else:
+                self.free_pages = torch.cat((self.free_pages, free_index))
         else:
             self.free_group.append(free_index)

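Taken together, the alloc() and free() hunks above make the sorted free list opt-in: with need_sort set, freed indices are parked in release_pages and only merged (and sorted) back into free_pages once the fast-path allocation would fail; with it unset, frees return straight to free_pages. A self-contained sketch of that pattern (illustrative class and names, not the sglang allocator itself):

import torch

class LazyFreeList:
    """Illustrative sketch of the need_sort free-list strategy."""

    def __init__(self, size: int, need_sort: bool):
        self.need_sort = need_sort
        self.free_pages = torch.arange(1, size + 1, dtype=torch.int64)
        self.release_pages = torch.empty(0, dtype=torch.int64)

    def merge_and_sort_free(self):
        if len(self.release_pages) > 0:
            self.free_pages, _ = torch.sort(
                torch.cat((self.free_pages, self.release_pages))
            )
            self.release_pages = torch.empty(0, dtype=torch.int64)

    def alloc(self, need_size: int):
        # Only pay the merge/sort cost when the fast path would fail.
        if self.need_sort and need_size > len(self.free_pages):
            self.merge_and_sort_free()
        if need_size > len(self.free_pages):
            return None
        out, self.free_pages = self.free_pages[:need_size], self.free_pages[need_size:]
        return out

    def free(self, indices: torch.Tensor):
        if self.need_sort:
            # Defer: indices go to the release list until a merge is forced.
            self.release_pages = torch.cat((self.release_pages, indices))
        else:
            # Eager: indices are immediately reusable (order not maintained).
            self.free_pages = torch.cat((self.free_pages, indices))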
@@ -170,8 +185,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: SWAKVPool,
+        need_sort: bool,
     ):
-        super().__init__(size, 1, dtype, device, kvcache)
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         assert isinstance(kvcache, SWAKVPool)
         self._size_full = size
         self._size_swa = size_swa
@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             dtype,
             device,
             kvcache.full_kv_pool,
+            need_sort,
         )
         self.swa_attn_allocator = TokenToKVPoolAllocator(
             size_swa,
             dtype,
             device,
             kvcache.swa_kv_pool,
+            need_sort,
         )
         self.full_to_swa_index_mapping = torch.empty(
             size + size_swa + 1,
@@ -418,8 +436,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
@@ -433,7 +452,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"

         num_pages = need_size // self.page_size
-        if num_pages > len(self.free_pages):
+        if self.need_sort and num_pages > len(self.free_pages):
             self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None
@@ -460,18 +479,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(prefix_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
@@ -508,18 +521,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(seq_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
             seq_lens,
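In the decode hunk above, each sequence appends exactly one token, so the shared helper is called with extend_num_tokens=1 and the estimate reduces to the batch size, i.e. at most one fresh page per decoding sequence. A one-line check (illustrative numbers):

# With one new token per sequence, ceil(1 / page_size) == 1, so the bound is bs.
page_size, bs = 64, 16
assert bs * ((1 + page_size - 1) // page_size) == bs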
@@ -547,7 +554,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            if self.need_sort:
+                self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            else:
+                self.free_pages = torch.cat((free_page_indices, self.free_pages))
         else:
             self.free_group.append(free_index)

@@ -622,27 +632,6 @@ def alloc_extend_kernel_ascend(
         out_indices[end_pos[i] - num3 : end_pos[i]] = (
             free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
         ).view(-1)
-    return num_new_pages
-
-
-def alloc_decode_kernel_ascend(
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-):
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        seq_lens - 1 + page_size - 1
-    ) // page_size
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    for i in range(len(seq_lens)):
-        if num_new_pages[i]:
-            out_indices[i] = free_pages[start_new_pages[i]] * page_size
-        else:
-            out_indices[i] = last_loc[i] + 1
-    return num_new_pages


 class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
@@ -654,9 +643,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
-        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)

     def alloc_extend(
         self,
@@ -678,15 +667,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             .sum()
             .item()
         )
-        if estimated_num_new_pages > len(self.free_pages):
+        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()

-        bs = len(prefix_lens)
+        if estimated_num_new_pages > len(self.free_pages):
+            return None
+
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
         )

-        self.ret_values = alloc_extend_kernel_ascend(
+        alloc_extend_kernel_ascend(
             prefix_lens,
             seq_lens,
             last_loc,
@@ -699,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
+        self.free_pages = self.free_pages[estimated_num_new_pages:]
         return out_indices

     def alloc_decode(
@@ -716,39 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+
+        if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()

-        bs = len(seq_lens)
-        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+        if num_new_pages > len(self.free_pages):
+            return None

-        self.ret_values = alloc_decode_kernel_ascend(
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-        )
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages

         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
         self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def clear(self):
-        super().clear()
-        self.free_pages = self.free_pages.to(torch.int32)
-        self.release_pages = self.release_pages.to(torch.int32)
+        return out_indices.int()
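The rewritten Ascend decode path above replaces the removed per-sequence Python loop with a vectorized select: a sequence needs a fresh page exactly when its new length lands on the first slot of a page (seq_len % page_size == 1); otherwise its token reuses last_loc + 1. A CPU-only sketch of the same selection (illustrative function and inputs, not the allocator class; it uses .long() for index safety):

import torch

def ascend_style_alloc_decode(seq_lens, last_loc, free_pages, page_size):
    # 1 where the sequence just crossed into a new page, 0 otherwise.
    need_new_pages = (seq_lens % page_size == 1).long()
    num_new_pages = int(need_new_pages.sum().item())
    if num_new_pages > len(free_pages):
        return None  # the caller would merge/sort the free list or back off here

    # Prefix sums assign the i-th page-needing sequence the i-th free page.
    end_new_pages = torch.cumsum(need_new_pages, 0)
    start_new_pages = end_new_pages - need_new_pages
    if num_new_pages == 0:
        out_indices = last_loc + 1
    else:
        out_indices = (last_loc + 1) * (1 - need_new_pages) + free_pages[
            start_new_pages
        ] * page_size * need_new_pages
    return out_indices.int(), free_pages[num_new_pages:]

seq_lens = torch.tensor([65, 70, 129])   # 65 and 129 start a new page (page_size=64)
last_loc = torch.tensor([100, 200, 300])
free_pages = torch.tensor([5, 9, 11])
out, remaining = ascend_style_alloc_decode(seq_lens, last_loc, free_pages, 64)
# out -> [5*64, 201, 9*64] = [320, 201, 576]; remaining -> [11]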
sglang/srt/mem_cache/hiradix_cache.py

@@ -71,8 +71,10 @@ class HiRadixCache(RadixCache):
         self.tp_group = tp_cache_group
         self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
-        # todo: customizable storage prefetch threshold
+        # todo: customizable storage prefetch threshold and timeout
         self.prefetch_threshold = 256
+        self.prefetch_timeout = 3  # seconds
+        self.prefetch_stop_policy = hicache_storage_prefetch_policy

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -87,13 +89,6 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )

-        self.prefetch_stop_policy = hicache_storage_prefetch_policy
-        # todo: customizable storage prefetch timeout
-        self.prefetch_timeout = 3  # seconds
-        logger.info(
-            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
-        )
-
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
@@ -151,7 +146,7 @@ class HiRadixCache(RadixCache):

     def write_backup_storage(self, node: TreeNode):
         operation_id = self.cache_controller.write_storage(
-            node.host_value, node.key, node.
+            node.host_value, node.key, node.hash_value
         )
         self.ongoing_backup[operation_id] = node
         node.protect_host()
@@ -414,18 +409,18 @@ class HiRadixCache(RadixCache):
                 group=self.tp_group,
             )
             for _ in range(queue_size.item()):
-                ack_id, hash_value = (
-                    self.cache_controller.ack_backup_queue.get()
-                )
+                ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
                 host_node = self.ongoing_backup[ack_id]
-
-
-
-
-
-
-
-
+
+                if completed_tokens > 0:
+                    if completed_tokens < len(host_node.key):
+                        # backup is only partially successful, split the node
+                        new_node = self._split_node(
+                            host_node.key, host_node, completed_tokens
+                        )
+                        new_node.backuped_storage = True
+                    else:
+                        host_node.backuped_storage = True
                 host_node.release_host()
                 del self.ongoing_backup[ack_id]

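The ack handling above now tolerates partially completed storage writes: when only completed_tokens of a node's key reached storage, the node is split at that boundary and only the completed prefix is marked as backed up. A minimal illustration of that bookkeeping with a hypothetical Node type (the real code splits a radix-tree node via _split_node):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    # Hypothetical stand-in for a radix-tree node holding token ids in `key`.
    key: List[int]
    backuped_storage: bool = False
    prefix: Optional["Node"] = None  # completed prefix produced by a split


def on_backup_ack(node: Node, completed_tokens: int) -> None:
    if completed_tokens <= 0:
        return  # nothing reached storage; leave the node unmarked
    if completed_tokens < len(node.key):
        # Backup only partially succeeded: split at the completed boundary and
        # mark just the prefix that is actually persisted.
        node.prefix = Node(key=node.key[:completed_tokens], backuped_storage=True)
        node.key = node.key[completed_tokens:]
    else:
        node.backuped_storage = True


n = Node(key=list(range(10)))
on_backup_ack(n, 6)
assert n.prefix is not None and n.prefix.backuped_storage and not n.backuped_storage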
@@ -471,6 +466,10 @@ class HiRadixCache(RadixCache):
             req_id
         ]

+        if operation.host_indices is None:
+            # prefetch has not been issued due to insufficient host memory
+            return True
+
         if not self.can_terminate_prefetch(operation):
             return False

@@ -565,10 +564,6 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
-        if host_indices is None:
-            last_host_node.release_host()
-            # no sufficient host memory to prefetch
-            return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
@@ -717,6 +712,21 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)

+        if self.enable_storage:
+            last_hash = node.get_last_hash_value()
+            assert (node == self.root_node) or (
+                last_hash is not None
+            ), "Parent node must have a hash value with storage enabled"
+            new_node.hash_value = []
+            for idx in range(0, len(key), self.page_size):
+                new_node.hash_value.append(
+                    self.cache_controller.get_hash_str(
+                        key[idx : idx + self.page_size],
+                        prior_hash=last_hash,
+                    )
+                )
+                last_hash = new_node.hash_value[-1]
+
         if self.cache_controller.write_policy != "write_back":
             self.inc_hit_count(new_node)
         return total_prefix_length
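The insertion hunk above derives one hash per page and chains each page's hash to the previous one, so a node's hash_value list identifies its entire prefix, not just its own tokens. A rough sketch of that chaining, assuming get_hash_str mixes the page's token ids with the prior hash (hash_page below is a stand-in, not the cache controller's helper):

import hashlib
from typing import List, Optional

def hash_page(tokens: List[int], prior_hash: Optional[str]) -> str:
    # Stand-in for cache_controller.get_hash_str: mix the prior hash with this
    # page's token ids so identical pages under different prefixes hash differently.
    h = hashlib.sha256()
    if prior_hash is not None:
        h.update(prior_hash.encode())
    h.update(",".join(map(str, tokens)).encode())
    return h.hexdigest()

def page_hash_chain(key: List[int], page_size: int, last_hash: Optional[str]) -> List[str]:
    hashes = []
    for idx in range(0, len(key), page_size):
        last_hash = hash_page(key[idx : idx + page_size], prior_hash=last_hash)
        hashes.append(last_hash)
    return hashes

# Two pages of 4 tokens each -> two chained hashes.
chain = page_hash_chain(list(range(8)), page_size=4, last_hash=None)
assert len(chain) == 2 and chain[0] != chain[1]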