sglang-0.5.0rc0-py3-none-any.whl → sglang-0.5.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130):
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -43,12 +43,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
43
43
  dtype: torch.dtype,
44
44
  device: str,
45
45
  kvcache: KVCache,
46
+ need_sort: bool,
46
47
  ):
47
48
  self.size = size
48
49
  self.page_size = page_size
49
50
  self.dtype = dtype
50
51
  self.device = device
51
52
  self._kvcache = kvcache
53
+ self.need_sort = need_sort
52
54
 
53
55
  self.free_pages = None
54
56
  self.release_pages = None
@@ -79,6 +81,9 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
79
81
  if self.free_group:
80
82
  self.free(torch.cat(self.free_group))
81
83
 
84
+ def estimated_num_new_pages(self, bs, extend_num_tokens):
85
+ return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
86
+
82
87
  def merge_and_sort_free(self):
83
88
  if len(self.release_pages) > 0:
84
89
  self.free_pages = torch.cat((self.free_pages, self.release_pages))
@@ -117,8 +122,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
117
122
  class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
118
123
  """An allocator managing the indices to kv cache data."""
119
124
 
120
- def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
121
- super().__init__(size, 1, dtype, device, kvcache)
125
+ def __init__(
126
+ self,
127
+ size: int,
128
+ dtype: torch.dtype,
129
+ device: str,
130
+ kvcache: KVCache,
131
+ need_sort: bool,
132
+ ):
133
+ super().__init__(size, 1, dtype, device, kvcache, need_sort)
122
134
  self.clear()
123
135
 
124
136
  def clear(self):
@@ -135,7 +147,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
135
147
  return len(self.free_pages) + len(self.release_pages)
136
148
 
137
149
  def alloc(self, need_size: int):
138
- if need_size > len(self.free_pages):
150
+ if self.need_sort and need_size > len(self.free_pages):
139
151
  self.merge_and_sort_free()
140
152
  if need_size > len(self.free_pages):
141
153
  return None
@@ -149,7 +161,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
149
161
  return
150
162
 
151
163
  if self.is_not_in_free_group:
152
- self.release_pages = torch.cat((self.release_pages, free_index))
164
+ if self.need_sort:
165
+ self.release_pages = torch.cat((self.release_pages, free_index))
166
+ else:
167
+ self.free_pages = torch.cat((self.free_pages, free_index))
153
168
  else:
154
169
  self.free_group.append(free_index)
155
170
 
@@ -170,8 +185,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
170
185
  dtype: torch.dtype,
171
186
  device: str,
172
187
  kvcache: SWAKVPool,
188
+ need_sort: bool,
173
189
  ):
174
- super().__init__(size, 1, dtype, device, kvcache)
190
+ super().__init__(size, 1, dtype, device, kvcache, need_sort)
175
191
  assert isinstance(kvcache, SWAKVPool)
176
192
  self._size_full = size
177
193
  self._size_swa = size_swa
@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
180
196
  dtype,
181
197
  device,
182
198
  kvcache.full_kv_pool,
199
+ need_sort,
183
200
  )
184
201
  self.swa_attn_allocator = TokenToKVPoolAllocator(
185
202
  size_swa,
186
203
  dtype,
187
204
  device,
188
205
  kvcache.swa_kv_pool,
206
+ need_sort,
189
207
  )
190
208
  self.full_to_swa_index_mapping = torch.empty(
191
209
  size + size_swa + 1,
@@ -418,8 +436,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
418
436
  dtype: torch.dtype,
419
437
  device: str,
420
438
  kvcache: KVCache,
439
+ need_sort: bool,
421
440
  ):
422
- super().__init__(size, page_size, dtype, device, kvcache)
441
+ super().__init__(size, page_size, dtype, device, kvcache, need_sort)
423
442
  self.num_pages = size // page_size
424
443
  self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
425
444
  self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
@@ -433,7 +452,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
433
452
  ), "The allocation size should be page-aligned"
434
453
 
435
454
  num_pages = need_size // self.page_size
436
- if num_pages > len(self.free_pages):
455
+ if self.need_sort and num_pages > len(self.free_pages):
437
456
  self.merge_and_sort_free()
438
457
  if num_pages > len(self.free_pages):
439
458
  return None
@@ -460,18 +479,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
460
479
  (last_loc + 1) % self.page_size == prefix_lens % self.page_size
461
480
  )
462
481
 
463
- estimated_num_new_pages = (
464
- (
465
- (seq_lens + self.page_size - 1) // self.page_size
466
- - (prefix_lens + self.page_size - 1) // self.page_size
467
- )
468
- .sum()
469
- .item()
470
- )
471
- if estimated_num_new_pages > len(self.free_pages):
482
+ bs = len(prefix_lens)
483
+ if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
484
+ self.free_pages
485
+ ):
472
486
  self.merge_and_sort_free()
473
487
 
474
- bs = len(prefix_lens)
475
488
  out_indices = torch.empty(
476
489
  (extend_num_tokens,), dtype=torch.int64, device=self.device
477
490
  )
@@ -508,18 +521,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
508
521
  (last_loc + 2) % self.page_size == seq_lens % self.page_size
509
522
  )
510
523
 
511
- estimated_num_new_pages = (
512
- (
513
- (seq_lens + self.page_size - 1) // self.page_size
514
- - (seq_lens - 1 + self.page_size - 1) // self.page_size
515
- )
516
- .sum()
517
- .item()
518
- )
519
- if estimated_num_new_pages > len(self.free_pages):
524
+ bs = len(seq_lens)
525
+ if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
526
+ self.free_pages
527
+ ):
520
528
  self.merge_and_sort_free()
521
529
 
522
- bs = len(seq_lens)
523
530
  out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
524
531
  alloc_decode_kernel[(bs,)](
525
532
  seq_lens,
@@ -547,7 +554,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
547
554
 
548
555
  if self.is_not_in_free_group:
549
556
  free_page_indices = torch.unique(free_index // self.page_size)
550
- self.release_pages = torch.cat((free_page_indices, self.release_pages))
557
+ if self.need_sort:
558
+ self.release_pages = torch.cat((free_page_indices, self.release_pages))
559
+ else:
560
+ self.free_pages = torch.cat((free_page_indices, self.free_pages))
551
561
  else:
552
562
  self.free_group.append(free_index)
553
563
 
@@ -622,27 +632,6 @@ def alloc_extend_kernel_ascend(
622
632
  out_indices[end_pos[i] - num3 : end_pos[i]] = (
623
633
  free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
624
634
  ).view(-1)
625
- return num_new_pages
626
-
627
-
628
- def alloc_decode_kernel_ascend(
629
- seq_lens,
630
- last_loc,
631
- free_pages,
632
- out_indices,
633
- page_size,
634
- ):
635
- num_new_pages = (seq_lens + page_size - 1) // page_size - (
636
- seq_lens - 1 + page_size - 1
637
- ) // page_size
638
- end_new_pages = torch.cumsum(num_new_pages, 0)
639
- start_new_pages = end_new_pages - num_new_pages
640
- for i in range(len(seq_lens)):
641
- if num_new_pages[i]:
642
- out_indices[i] = free_pages[start_new_pages[i]] * page_size
643
- else:
644
- out_indices[i] = last_loc[i] + 1
645
- return num_new_pages
646
635
 
647
636
 
648
637
  class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
@@ -654,9 +643,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
654
643
  dtype: torch.dtype,
655
644
  device: str,
656
645
  kvcache: KVCache,
646
+ need_sort: bool,
657
647
  ):
658
- super().__init__(size, page_size, dtype, device, kvcache)
659
- self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
648
+ super().__init__(size, page_size, dtype, device, kvcache, need_sort)
660
649
 
661
650
  def alloc_extend(
662
651
  self,
@@ -678,15 +667,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
678
667
  .sum()
679
668
  .item()
680
669
  )
681
- if estimated_num_new_pages > len(self.free_pages):
670
+ if self.need_sort and estimated_num_new_pages > len(self.free_pages):
682
671
  self.merge_and_sort_free()
683
672
 
684
- bs = len(prefix_lens)
673
+ if estimated_num_new_pages > len(self.free_pages):
674
+ return None
675
+
685
676
  out_indices = torch.empty(
686
677
  (extend_num_tokens,), dtype=torch.int32, device=self.device
687
678
  )
688
679
 
689
- self.ret_values = alloc_extend_kernel_ascend(
680
+ alloc_extend_kernel_ascend(
690
681
  prefix_lens,
691
682
  seq_lens,
692
683
  last_loc,
@@ -699,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
699
690
  if self.debug_mode:
700
691
  assert len(torch.unique(out_indices)) == len(out_indices)
701
692
 
702
- num_new_pages = self.ret_values.sum()
703
- if num_new_pages > len(self.free_pages):
704
- return None
705
-
706
- self.free_pages = self.free_pages[num_new_pages:]
693
+ self.free_pages = self.free_pages[estimated_num_new_pages:]
707
694
  return out_indices
708
695
 
709
696
  def alloc_decode(
@@ -716,39 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
716
703
  (last_loc + 2) % self.page_size == seq_lens % self.page_size
717
704
  )
718
705
 
719
- estimated_num_new_pages = (
720
- (
721
- (seq_lens + self.page_size - 1) // self.page_size
722
- - (seq_lens - 1 + self.page_size - 1) // self.page_size
723
- )
724
- .sum()
725
- .item()
726
- )
727
- if estimated_num_new_pages > len(self.free_pages):
706
+ need_new_pages = (seq_lens % self.page_size == 1).int()
707
+ num_new_pages = need_new_pages.sum().item()
708
+
709
+ if num_new_pages > len(self.free_pages):
728
710
  self.merge_and_sort_free()
729
711
 
730
- bs = len(seq_lens)
731
- out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
712
+ if num_new_pages > len(self.free_pages):
713
+ return None
732
714
 
733
- self.ret_values = alloc_decode_kernel_ascend(
734
- seq_lens,
735
- last_loc,
736
- self.free_pages,
737
- out_indices,
738
- self.page_size,
739
- )
715
+ end_new_pages = torch.cumsum(need_new_pages, 0)
716
+ start_new_pages = end_new_pages - need_new_pages
717
+ if num_new_pages == 0:
718
+ out_indices = last_loc + 1
719
+ else:
720
+ out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
721
+ start_new_pages
722
+ ] * self.page_size * need_new_pages
740
723
 
741
724
  if self.debug_mode:
742
725
  assert len(torch.unique(out_indices)) == len(out_indices)
743
726
 
744
- num_new_pages = self.ret_values.sum()
745
- if num_new_pages > len(self.free_pages):
746
- return None
747
-
748
727
  self.free_pages = self.free_pages[num_new_pages:]
749
- return out_indices
750
-
751
- def clear(self):
752
- super().clear()
753
- self.free_pages = self.free_pages.to(torch.int32)
754
- self.release_pages = self.release_pages.to(torch.int32)
728
+ return out_indices.int()
@@ -15,7 +15,7 @@ from sglang.srt.distributed import (
15
15
  )
16
16
 
17
17
 
18
- def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
18
+ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
19
19
  hasher = hashlib.sha256()
20
20
 
21
21
  if prior_hash:
@@ -71,8 +71,10 @@ class HiRadixCache(RadixCache):
71
71
  self.tp_group = tp_cache_group
72
72
  self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
73
73
  self.enable_storage = hicache_storage_backend is not None
74
- # todo: customizable storage prefetch threshold
74
+ # todo: customizable storage prefetch threshold and timeout
75
75
  self.prefetch_threshold = 256
76
+ self.prefetch_timeout = 3 # seconds
77
+ self.prefetch_stop_policy = hicache_storage_prefetch_policy
76
78
 
77
79
  self.load_cache_event = threading.Event()
78
80
  self.cache_controller = HiCacheController(
@@ -87,13 +89,6 @@ class HiRadixCache(RadixCache):
87
89
  prefetch_threshold=self.prefetch_threshold,
88
90
  )
89
91
 
90
- self.prefetch_stop_policy = hicache_storage_prefetch_policy
91
- # todo: customizable storage prefetch timeout
92
- self.prefetch_timeout = 3 # seconds
93
- logger.info(
94
- f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
95
- )
96
-
97
92
  # record the nodes with ongoing write through
98
93
  self.ongoing_write_through = {}
99
94
  # record the node segments with ongoing load back
@@ -151,7 +146,7 @@ class HiRadixCache(RadixCache):
151
146
 
152
147
  def write_backup_storage(self, node: TreeNode):
153
148
  operation_id = self.cache_controller.write_storage(
154
- node.host_value, node.key, node.parent.get_last_hash_value()
149
+ node.host_value, node.key, node.hash_value
155
150
  )
156
151
  self.ongoing_backup[operation_id] = node
157
152
  node.protect_host()
@@ -414,18 +409,18 @@ class HiRadixCache(RadixCache):
414
409
  group=self.tp_group,
415
410
  )
416
411
  for _ in range(queue_size.item()):
417
- ack_id, hash_value, completed_tokens = (
418
- self.cache_controller.ack_backup_queue.get()
419
- )
412
+ ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
420
413
  host_node = self.ongoing_backup[ack_id]
421
- if completed_tokens == 0:
422
- host_node.hash_value = None
423
- elif completed_tokens < len(host_node.key):
424
- # backup is only partially successful, split the node
425
- new_node = self._split_node(host_node.key, host_node, completed_tokens)
426
- new_node.hash_value = hash_value
427
- else:
428
- host_node.hash_value = hash_value
414
+
415
+ if completed_tokens > 0:
416
+ if completed_tokens < len(host_node.key):
417
+ # backup is only partially successful, split the node
418
+ new_node = self._split_node(
419
+ host_node.key, host_node, completed_tokens
420
+ )
421
+ new_node.backuped_storage = True
422
+ else:
423
+ host_node.backuped_storage = True
429
424
  host_node.release_host()
430
425
  del self.ongoing_backup[ack_id]
431
426
 
@@ -471,6 +466,10 @@ class HiRadixCache(RadixCache):
471
466
  req_id
472
467
  ]
473
468
 
469
+ if operation.host_indices is None:
470
+ # prefetch has not been issued due to insufficient host memory
471
+ return True
472
+
474
473
  if not self.can_terminate_prefetch(operation):
475
474
  return False
476
475
 
@@ -565,10 +564,6 @@ class HiRadixCache(RadixCache):
565
564
  if host_indices is None:
566
565
  self.evict_host(prefetch_length)
567
566
  host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
568
- if host_indices is None:
569
- last_host_node.release_host()
570
- # no sufficient host memory to prefetch
571
- return
572
567
  operation = self.cache_controller.prefetch(
573
568
  req_id, host_indices, new_input_tokens, last_hash
574
569
  )
@@ -717,6 +712,21 @@ class HiRadixCache(RadixCache):
717
712
  node.children[child_key] = new_node
718
713
  self.evictable_size_ += len(value)
719
714
 
715
+ if self.enable_storage:
716
+ last_hash = node.get_last_hash_value()
717
+ assert (node == self.root_node) or (
718
+ last_hash is not None
719
+ ), "Parent node must have a hash value with storage enabled"
720
+ new_node.hash_value = []
721
+ for idx in range(0, len(key), self.page_size):
722
+ new_node.hash_value.append(
723
+ self.cache_controller.get_hash_str(
724
+ key[idx : idx + self.page_size],
725
+ prior_hash=last_hash,
726
+ )
727
+ )
728
+ last_hash = new_node.hash_value[-1]
729
+
720
730
  if self.cache_controller.write_policy != "write_back":
721
731
  self.inc_hit_count(new_node)
722
732
  return total_prefix_length