sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff reflects the contents of publicly available package versions as they appear in their respective public registries and is provided for informational purposes only.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -43,12 +43,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
         self.size = size
         self.page_size = page_size
         self.dtype = dtype
         self.device = device
         self._kvcache = kvcache
+        self.need_sort = need_sort

         self.free_pages = None
         self.release_pages = None
@@ -79,6 +81,9 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def estimated_num_new_pages(self, bs, extend_num_tokens):
+        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
+
     def merge_and_sort_free(self):
         if len(self.release_pages) > 0:
             self.free_pages = torch.cat((self.free_pages, self.release_pages))
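
For context on the helper added above: the page estimate is now a cheap host-side upper bound, bs * ceil(extend_num_tokens / page_size), rather than an exact per-sequence difference summed on the GPU. A small illustrative calculation (the values below are invented for the example, not sglang defaults):

# Illustrative arithmetic for the upper-bound page estimate; not sglang code.
page_size = 16
bs = 4                   # sequences in the batch
extend_num_tokens = 50   # total new tokens across the batch

# ceil division without floats, matching the expression in the diff
pages_upper_bound = (extend_num_tokens + page_size - 1) // page_size
estimated_num_new_pages = bs * pages_upper_bound
print(estimated_num_new_pages)  # 4 * ceil(50 / 16) = 4 * 4 = 16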
@@ -117,8 +122,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
 class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """An allocator managing the indices to kv cache data."""

-    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
-        super().__init__(size, 1, dtype, device, kvcache)
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         self.clear()

     def clear(self):
@@ -135,7 +147,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         return len(self.free_pages) + len(self.release_pages)

     def alloc(self, need_size: int):
-        if need_size > len(self.free_pages):
+        if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
         if need_size > len(self.free_pages):
             return None
@@ -149,7 +161,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return

         if self.is_not_in_free_group:
-            self.release_pages = torch.cat((self.release_pages, free_index))
+            if self.need_sort:
+                self.release_pages = torch.cat((self.release_pages, free_index))
+            else:
+                self.free_pages = torch.cat((self.free_pages, free_index))
         else:
             self.free_group.append(free_index)

@@ -170,8 +185,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: SWAKVPool,
+        need_sort: bool,
     ):
-        super().__init__(size, 1, dtype, device, kvcache)
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         assert isinstance(kvcache, SWAKVPool)
         self._size_full = size
         self._size_swa = size_swa
@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             dtype,
             device,
             kvcache.full_kv_pool,
+            need_sort,
         )
         self.swa_attn_allocator = TokenToKVPoolAllocator(
             size_swa,
             dtype,
             device,
             kvcache.swa_kv_pool,
+            need_sort,
         )
         self.full_to_swa_index_mapping = torch.empty(
             size + size_swa + 1,
@@ -418,8 +436,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
@@ -433,7 +452,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"

         num_pages = need_size // self.page_size
-        if num_pages > len(self.free_pages):
+        if self.need_sort and num_pages > len(self.free_pages):
             self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None
@@ -460,18 +479,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(prefix_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
@@ -508,18 +521,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(seq_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
             seq_lens,
@@ -547,7 +554,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            if self.need_sort:
+                self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            else:
+                self.free_pages = torch.cat((free_page_indices, self.free_pages))
         else:
             self.free_group.append(free_index)

@@ -622,27 +632,6 @@ def alloc_extend_kernel_ascend(
         out_indices[end_pos[i] - num3 : end_pos[i]] = (
             free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
         ).view(-1)
-    return num_new_pages
-
-
-def alloc_decode_kernel_ascend(
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-):
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        seq_lens - 1 + page_size - 1
-    ) // page_size
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    for i in range(len(seq_lens)):
-        if num_new_pages[i]:
-            out_indices[i] = free_pages[start_new_pages[i]] * page_size
-        else:
-            out_indices[i] = last_loc[i] + 1
-    return num_new_pages


 class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
@@ -654,9 +643,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
-        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)

     def alloc_extend(
         self,
@@ -678,15 +667,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             .sum()
             .item()
         )
-        if estimated_num_new_pages > len(self.free_pages):
+        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()

-        bs = len(prefix_lens)
+        if estimated_num_new_pages > len(self.free_pages):
+            return None
+
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
         )

-        self.ret_values = alloc_extend_kernel_ascend(
+        alloc_extend_kernel_ascend(
             prefix_lens,
             seq_lens,
             last_loc,
@@ -699,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
+        self.free_pages = self.free_pages[estimated_num_new_pages:]
         return out_indices

     def alloc_decode(
@@ -716,39 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+
+        if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()

-        bs = len(seq_lens)
-        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+        if num_new_pages > len(self.free_pages):
+            return None

-        self.ret_values = alloc_decode_kernel_ascend(
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-        )
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages

         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
         self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def clear(self):
-        super().clear()
-        self.free_pages = self.free_pages.to(torch.int32)
-        self.release_pages = self.release_pages.to(torch.int32)
+        return out_indices.int()
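
The rewritten Ascend decode path above drops the Python alloc_decode_kernel_ascend helper and computes indices directly from the observation that a decoding sequence needs a fresh page exactly when its new length falls on the first slot of a page (seq_lens % page_size == 1). A hedged, self-contained sketch of the same index arithmetic with invented inputs:

# Illustrative decode-time index computation mirroring the hunk above;
# the tensor values are made up, not sglang state.
import torch

page_size = 4
seq_lens = torch.tensor([5, 7, 9])      # lengths after appending the decode token
last_loc = torch.tensor([12, 30, 47])   # location of each sequence's previous token
free_pages = torch.tensor([100, 101])   # page ids currently free

need_new_pages = (seq_lens % page_size == 1).int()   # [1, 0, 1]
num_new_pages = int(need_new_pages.sum().item())     # 2

end_new_pages = torch.cumsum(need_new_pages, 0)      # [1, 1, 2]
start_new_pages = end_new_pages - need_new_pages     # [0, 1, 1]

if num_new_pages == 0:
    out_indices = last_loc + 1
else:
    # keep last_loc + 1 where no page is needed, else the first slot of a fresh page
    out_indices = (last_loc + 1) * (1 - need_new_pages) + free_pages[
        start_new_pages
    ] * page_size * need_new_pages

print(out_indices)  # tensor([400,  31, 404])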
@@ -15,7 +15,7 @@ from sglang.srt.distributed import (
 )


-def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
+def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()

     if prior_hash:
@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional

 import torch

-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):

         if hicache_io_backend == "direct":
@@ -69,8 +71,10 @@ class HiRadixCache(RadixCache):
         self.tp_group = tp_cache_group
         self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
-        # todo: customizable storage prefetch threshold
+        # todo: customizable storage prefetch threshold and timeout
         self.prefetch_threshold = 256
+        self.prefetch_timeout = 3  # seconds
+        self.prefetch_stop_policy = hicache_storage_prefetch_policy

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -142,7 +146,7 @@ class HiRadixCache(RadixCache):

     def write_backup_storage(self, node: TreeNode):
         operation_id = self.cache_controller.write_storage(
-            node.host_value, node.key, node.parent.get_last_hash_value()
+            node.host_value, node.key, node.hash_value
         )
         self.ongoing_backup[operation_id] = node
         node.protect_host()
@@ -385,9 +389,10 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
                 del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
             else:
                 # the revoked operation already got terminated
                 pass
@@ -404,25 +409,56 @@ class HiRadixCache(RadixCache):
                 group=self.tp_group,
             )
         for _ in range(queue_size.item()):
-            ack_id, hash_value, completed_tokens = (
-                self.cache_controller.ack_backup_queue.get()
-            )
+            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
             host_node = self.ongoing_backup[ack_id]
-            if completed_tokens == 0:
-                host_node.hash_value = None
-            elif completed_tokens < len(host_node.key):
-                # backup is only partially successful, split the node
-                new_node = self._split_node(host_node.key, host_node, completed_tokens)
-                new_node.hash_value = hash_value
-            else:
-                host_node.hash_value = hash_value
+
+            if completed_tokens > 0:
+                if completed_tokens < len(host_node.key):
+                    # backup is only partially successful, split the node
+                    new_node = self._split_node(
+                        host_node.key, host_node, completed_tokens
+                    )
+                    new_node.backuped_storage = True
+                else:
+                    host_node.backuped_storage = True
             host_node.release_host()
             del self.ongoing_backup[ack_id]

-    def check_prefetch_progress(self, req_id: str):
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
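
The new can_terminate_prefetch method above encodes three storage-prefetch stop policies: best_effort terminates immediately, wait_complete requires every hashed page to have arrived, and timeout waits up to prefetch_timeout seconds; with TP > 1 the boolean is min-reduced across ranks so all workers reach the same decision. A condensed single-process sketch of that decision logic (PrefetchState is a hypothetical stand-in for the controller's PrefetchOperation):

# Single-process sketch of the stop-policy decision; not the sglang implementation.
import time
from dataclasses import dataclass

@dataclass
class PrefetchState:               # hypothetical stand-in for PrefetchOperation
    completed_tokens: int
    num_hashed_pages: int
    start_time: float

def can_terminate(op: PrefetchState, policy: str, page_size: int, timeout_s: float) -> bool:
    if policy == "best_effort":
        return True                # never block the scheduler on storage
    completed = op.completed_tokens == op.num_hashed_pages * page_size
    if policy == "wait_complete":
        return completed
    if policy == "timeout":
        return completed or (time.monotonic() - op.start_time > timeout_s)
    return True                    # unknown policy: fall back to not blocking

# Example: a half-finished prefetch under the "timeout" policy
op = PrefetchState(completed_tokens=128, num_hashed_pages=4, start_time=time.monotonic())
print(can_terminate(op, "timeout", page_size=64, timeout_s=3.0))  # False until ~3 s elapse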
@@ -430,13 +466,20 @@ class HiRadixCache(RadixCache):
             req_id
         ]

+        if operation.host_indices is None:
+            # prefetch has not been issued due to insufficient host memory
+            return True
+
+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -464,6 +507,9 @@ class HiRadixCache(RadixCache):
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True

     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -518,10 +564,6 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
-        if host_indices is None:
-            last_host_node.release_host()
-            # no sufficient host memory to prefetch
-            return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
@@ -531,6 +573,7 @@ class HiRadixCache(RadixCache):
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
@@ -669,6 +712,21 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)

+        if self.enable_storage:
+            last_hash = node.get_last_hash_value()
+            assert (node == self.root_node) or (
+                last_hash is not None
+            ), "Parent node must have a hash value with storage enabled"
+            new_node.hash_value = []
+            for idx in range(0, len(key), self.page_size):
+                new_node.hash_value.append(
+                    self.cache_controller.get_hash_str(
+                        key[idx : idx + self.page_size],
+                        prior_hash=last_hash,
+                    )
+                )
+                last_hash = new_node.hash_value[-1]
+
         if self.cache_controller.write_policy != "write_back":
             self.inc_hit_count(new_node)
         return total_prefix_length
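
The last hunk gives each newly inserted radix node a per-page hash chain: every page hash is seeded with the previous page's hash via prior_hash, so a page's storage key depends on its entire prefix. A minimal standalone sketch of that chaining; the local get_hash_str below is only a stand-in for cache_controller.get_hash_str, and its exact byte encoding is an assumption:

# Sketch of prefix-chained page hashing; the byte encoding is illustrative only.
import hashlib
from typing import List, Optional

def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
    # Stand-in for the helper in hicache_storage.py: hash the prior page's
    # digest (if any) followed by this page's token ids.
    hasher = hashlib.sha256()
    if prior_hash:
        hasher.update(bytes.fromhex(prior_hash))
    for t in token_ids:
        hasher.update(t.to_bytes(4, "little", signed=False))
    return hasher.hexdigest()

page_size = 4
key = [11, 12, 13, 14, 15, 16, 17, 18]   # token ids covered by the new node
last_hash = None                          # parent hash; None at the root

hash_value = []
for idx in range(0, len(key), page_size):
    hash_value.append(get_hash_str(key[idx : idx + page_size], prior_hash=last_hash))
    last_hash = hash_value[-1]            # chain: the next page depends on this one

print(hash_value)  # two chained hex digests, one per page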