sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares the contents of two package versions as published to their public registry. It is provided for informational purposes only.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/allocator.py

@@ -20,7 +20,6 @@ Page-aligned memory pool.
 """
 
 import abc
-import weakref
 from typing import TYPE_CHECKING
 
 import torch
@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))
 
-    def estimated_num_new_pages(self, bs, extend_num_tokens):
-        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
-
     def merge_and_sort_free(self):
         if len(self.release_pages) > 0:
             self.free_pages = torch.cat((self.free_pages, self.release_pages))
@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc(self, need_size: int):
         if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
+
         if need_size > len(self.free_pages):
             return None
 
@@ -442,6 +439,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
+        self.seen_max_num_extend_tokens_next_power_of_2 = 1
        self.clear()
 
     def alloc(self, need_size: int):
@@ -479,8 +477,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
+        self.seen_max_num_extend_tokens_next_power_of_2 = max(
+            self.seen_max_num_extend_tokens_next_power_of_2,
+            next_power_of_2(extend_num_tokens),
+        )
+
         bs = len(prefix_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
             self.free_pages
         ):
             self.merge_and_sort_free()
@@ -497,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.ret_values,
             next_power_of_2(bs),
             self.page_size,
-            next_power_of_2(extend_num_tokens),
+            self.seen_max_num_extend_tokens_next_power_of_2,
         )
 
         if self.debug_mode:
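Note on seen_max_num_extend_tokens_next_power_of_2: the extend kernel receives its token-count bound as a compile-time constant, so passing next_power_of_2(extend_num_tokens) directly can recompile the kernel every time a batch falls into a new bucket. Tracking a monotonically growing maximum caps the number of distinct variants at roughly log2 of the largest batch ever seen. A minimal standalone sketch of the idea (not the sglang kernel itself):

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n, for n >= 1.
        return 1 << (n - 1).bit_length()

    seen_max = 1
    for extend_num_tokens in (7, 120, 64, 300):
        seen_max = max(seen_max, next_power_of_2(extend_num_tokens))
        # seen_max only grows: 8 -> 128 -> 128 -> 512, so at most
        # log2(max_tokens_seen) distinct constants ever reach the kernel.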
@@ -522,9 +525,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             )
 
         bs = len(seq_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
-            self.free_pages
-        ):
+        if self.need_sort and bs > len(self.free_pages):
             self.merge_and_sort_free()
 
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
@@ -578,151 +579,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
-def alloc_extend_kernel_ascend(
-    prefix_lens,
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-    device,
-):
-    extend_lens = seq_lens - prefix_lens
-    end_pos = torch.cumsum(extend_lens, 0)
-    start_pos = end_pos - extend_lens
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    num_full_new_pages = (seq_lens) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    need_page = num_new_pages - num_full_new_pages
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
-    for i in range(len(prefix_lens)):
-        num1 = (
-            min(
-                seq_lens[i],
-                (prefix_lens[i] + page_size - 1) // page_size * page_size,
-            )
-            - prefix_lens[i]
-        )
-        if num1:
-            out_indices[start_pos[i] : start_pos[i] + num1] = (
-                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
-            )
-
-        num2 = (
-            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
-        ) * page_size
-        if num2:
-            pages = (
-                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
-                * page_size
-            )
-            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
-                pages.view(-1, 1) + pos_in_page.view(1, -1)
-            ).view(-1)
-
-        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
-        if num3:
-            out_indices[end_pos[i] - num3 : end_pos[i]] = (
-                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
-            ).view(-1)
-
-
-class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
-
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-        need_sort: bool,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
-
-    def alloc_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if estimated_num_new_pages > len(self.free_pages):
-            return None
-
-        out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
-        )
-
-        alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[estimated_num_new_pages:]
-        return out_indices
-
-    def alloc_decode(
-        self,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 2) % self.page_size == seq_lens % self.page_size
-            )
-
-        need_new_pages = (seq_lens % self.page_size == 1).int()
-        num_new_pages = need_new_pages.sum().item()
-
-        if num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        end_new_pages = torch.cumsum(need_new_pages, 0)
-        start_new_pages = end_new_pages - need_new_pages
-        if num_new_pages == 0:
-            out_indices = last_loc + 1
-        else:
-            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
-                start_new_pages
-            ] * self.page_size * need_new_pages
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices.int()
sglang/srt/mem_cache/allocator_ascend.py (new file)

@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+def alloc_extend_kernel_ascend(
+    prefix_lens,
+    seq_lens,
+    last_loc,
+    free_pages,
+    out_indices,
+    page_size,
+    device,
+):
+    extend_lens = seq_lens - prefix_lens
+    end_pos = torch.cumsum(extend_lens, 0)
+    start_pos = end_pos - extend_lens
+    num_new_pages = (seq_lens + page_size - 1) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    num_full_new_pages = (seq_lens) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    need_page = num_new_pages - num_full_new_pages
+    end_new_pages = torch.cumsum(num_new_pages, 0)
+    start_new_pages = end_new_pages - num_new_pages
+    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
+    for i in range(len(prefix_lens)):
+        num1 = (
+            min(
+                seq_lens[i],
+                (prefix_lens[i] + page_size - 1) // page_size * page_size,
+            )
+            - prefix_lens[i]
+        )
+        if num1:
+            out_indices[start_pos[i] : start_pos[i] + num1] = (
+                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
+            )
+
+        num2 = (
+            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
+        ) * page_size
+        if num2:
+            pages = (
+                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
+                * page_size
+            )
+            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
+                pages.view(-1, 1) + pos_in_page.view(1, -1)
+            ).view(-1)
+
+        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
+        if num3:
+            out_indices[end_pos[i] - num3 : end_pos[i]] = (
+                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
+            ).view(-1)
+
+
+class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
+
+    def alloc_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
+            )
+
+        num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if self.need_sort and num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        out_indices = torch.empty(
+            (extend_num_tokens,), dtype=torch.int32, device=self.device
+        )
+
+        alloc_extend_kernel_ascend(
+            prefix_lens,
+            seq_lens,
+            last_loc,
+            self.free_pages,
+            out_indices,
+            self.page_size,
+            self.device,
+        )
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices
+
+    def alloc_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 2) % self.page_size == seq_lens % self.page_size
+            )
+
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+
+        if num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices.int()
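Note on alloc_extend_kernel_ascend: each request's new tokens are placed in three parts: num1 tops off the partially filled last prefix page, num2 fills whole new pages drawn from free_pages, and num3 starts a trailing partial page. A worked check of the arithmetic with hypothetical numbers:

    def ceil_div(a: int, b: int) -> int:
        return -(-a // b)

    page_size, prefix_len, seq_len = 4, 6, 15

    num1 = min(seq_len, ceil_div(prefix_len, page_size) * page_size) - prefix_len
    num2 = (seq_len // page_size - ceil_div(prefix_len, page_size)) * page_size
    num3 = seq_len - seq_len // page_size * page_size

    assert (num1, num2, num3) == (2, 4, 3)
    assert num1 + num2 + num3 == seq_len - prefix_len  # every new token is placed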
sglang/srt/mem_cache/chunk_cache.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 

sglang/srt/mem_cache/hicache_storage.py

@@ -13,6 +13,11 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 
 
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
@@ -101,11 +106,16 @@ class HiCacheStorage(ABC):
 
 class HiCacheFile(HiCacheStorage):
 
-    def __init__(self, file_path: str = "/tmp/hicache"):
+    def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if is_dp_attention_enabled():
+            tp_rank = get_attention_tp_rank()
+            tp_size = get_attention_tp_size()
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            tp_size = get_tensor_model_parallel_world_size()
+
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
sglang/srt/mem_cache/memory_pool.py

@@ -849,7 +849,7 @@ class MLATokenToKVPool(KVCache):
             cache_k_rope = cache_k_rope.view(self.store_dtype)
 
         set_mla_kv_buffer_triton(
-            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+            self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
         )
 
     def get_cpu_copy(self, indices):
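Note on the start_layer fix: kv_buffer holds only the layers hosted on this rank and is indexed by local position, while callers pass global layer ids; without the rebase, any deployment with start_layer > 0 (e.g. pipeline parallelism) would address the wrong buffer. A standalone illustration with hypothetical numbers:

    # A rank hosting global layers 30..39 keeps ten local buffers.
    start_layer = 30
    kv_buffer = [f"layer-{start_layer + i}" for i in range(10)]

    layer_id = 32  # global id arriving from the caller
    assert kv_buffer[layer_id - start_layer] == "layer-32"  # rebased: correct slot
    # kv_buffer[layer_id] would raise IndexError here, or silently hit the
    # wrong layer whenever the list happened to be long enough.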
@@ -951,7 +951,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             cache_k = cache_k.to(self.dtype)
 
         if self.store_dtype != self.dtype:
-            cache_k = cache_k.view(store_dtype)
+            cache_k = cache_k.view(self.store_dtype)
 
         import torch_npu
 
@@ -1070,7 +1070,7 @@ def copy_all_layer_kv_cache(
     num_loop = tl.cdiv(stride, BLOCK_SIZE)
     for i in range(num_loop):
         copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
+        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
         value = tl.load(
             data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
         )
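Note on the mask fix: Python's `and` is not an elementwise operator and cannot combine two differently shaped boolean blocks, while `&` broadcasts and ANDs element by element, which is what the 2-D load mask needs. The same pitfall reproduced in plain PyTorch (a stand-in, since the original is a Triton kernel):

    import torch

    rows = torch.arange(4)[:, None] < 3  # shape (4, 1): valid rows
    cols = torch.arange(5)[None, :] < 2  # shape (1, 5): valid columns

    mask = rows & cols  # broadcasts to (4, 5), elementwise AND
    assert mask.sum().item() == 3 * 2
    # `rows and cols` would call bool() on a multi-element tensor and raise
    # "Boolean value of Tensor with more than one element is ambiguous".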
sglang/srt/mem_cache/memory_pool_host.py

@@ -7,6 +7,7 @@ from functools import wraps
 import psutil
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -307,6 +308,9 @@ class MHATokenToKVPoolHost(HostKVCache):
 
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token() // 2
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
@@ -484,8 +488,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-            key_list.append(f"{key_}_v")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
@@ -496,6 +500,21 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        key_list = []
+        buf_list = []
+
+        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+            key_list.append(f"{key}-k")
+            buf_list.append(self.k_buffer[i : i + self.page_size])
+            key_list.append(f"{key}-v")
+            buf_list.append(self.v_buffer[i : i + self.page_size])
+
+        return key_list, buf_list
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
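Note on get_buffer_with_hash: it relies on the page_first layout, where all token slots of a page are contiguous along the leading dimension, so each page can be handed out as a zero-copy view paired with its hash key. A minimal sketch with illustrative shapes (not the real pool dimensions):

    import torch

    page_size, layer_num, head_num, head_dim = 4, 2, 8, 16
    # page_first: the token-slot dimension leads, so pages are contiguous slices.
    k_buffer = torch.zeros(3 * page_size, layer_num, head_num, head_dim)

    indices = list(range(3 * page_size))  # three pages of token slots
    views = [k_buffer[i : i + page_size] for i in range(0, len(indices), page_size)]
    assert len(views) == 3
    assert views[0].data_ptr() == k_buffer.data_ptr()  # a view, not a copy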
@@ -538,6 +557,9 @@ class MLATokenToKVPoolHost(HostKVCache):
             * self.layer_num
         )
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token()
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (
@@ -704,3 +726,14 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
+
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        buf_list = []
+
+        for i in range(0, len(indices), self.page_size):
+            buf_list.append(self.kv_buffer[i : i + self.page_size])
+
+        return keys, buf_list
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py

@@ -7,10 +7,15 @@ import signal
 import threading
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 
@@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage):
 
     @staticmethod
     def from_env_config(
-        rank: int, bytes_per_page: int, dtype: torch.dtype
+        bytes_per_page: int, dtype: torch.dtype, rank: int = None
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
+        if rank is None:
+            rank = (
+                get_attention_tp_rank()
+                if is_dp_attention_enabled()
+                else get_tensor_model_parallel_rank()
+            )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
@@ -228,15 +240,23 @@ class HiCacheHF3FS(HiCacheStorage):
         )
 
     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
-        return self.batch_get([key], [target_location] if target_location else None)[0]
+        return self.batch_get(
+            [key],
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )[0]
 
     @synchronized()
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         page_indices = self.metadata_client.get_page_indices(self.rank, keys)
 
@@ -246,9 +266,15 @@ class HiCacheHF3FS(HiCacheStorage):
             batch_indices.append(i)
             file_offsets.append(page_index * self.bytes_per_page)
 
-        file_results = [
-            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
-        ]
+        if target_locations is not None:
+            for target_location in target_locations:
+                assert target_location.is_contiguous()
+            file_results = target_locations
+        else:
+            file_results = [
+                torch.empty(self.numel, dtype=self.dtype)
+                for _ in range(len(batch_indices))
+            ]
 
         futures = [
             self.executor.submit(
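Note on target_locations: accepting preallocated destination tensors turns batch_get into an in-place read; pages land directly in the caller's contiguous buffers instead of being staged through fresh torch.empty tensors and copied afterwards. A minimal sketch of the contract (fetch is a hypothetical page reader):

    import torch

    def batch_read(fetch, num_pages, numel, target_locations=None):
        # fetch(i, out) is assumed to fill `out` with page i's contents in place.
        if target_locations is not None:
            for t in target_locations:
                assert t.is_contiguous()  # in-place IO needs contiguous memory
            results = target_locations
        else:
            results = [torch.empty(numel) for _ in range(num_pages)]
        for i, out in enumerate(results):
            fetch(i, out)
        return results

    dst = [torch.empty(8) for _ in range(2)]
    out = batch_read(lambda i, t: t.fill_(float(i)), 2, 8, dst)
    assert out[0] is dst[0] and out[1][0].item() == 1.0  # no staging copy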
@@ -273,10 +299,27 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return results
 
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        return self.batch_set(
+            [key],
+            [value] if value is not None else None,
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -292,7 +335,8 @@ class HiCacheHF3FS(HiCacheStorage):
 
             batch_indices.append(i)
             file_offsets.append(page_index * self.bytes_per_page)
-            file_values.append(value.contiguous())
+            assert value.is_contiguous()
+            file_values.append(value)
 
         futures = [
             self.executor.submit(
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py

@@ -19,14 +19,13 @@ logger = logging.getLogger(__name__)
 
 
 def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
     if prior_hash:
         prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
     current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
+    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
 
 
 @dataclass
@@ -97,7 +96,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
-    def __init__(self):
+    def __init__(self, is_mla: bool = False):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -127,6 +126,7 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
+            self.is_mla = is_mla
 
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -223,11 +223,15 @@ class MooncakeStore(HiCacheStorage):
 
     def exists(self, keys) -> bool | dict:
         _keys = []
+        local_rank = get_tensor_model_parallel_rank()
         for key in keys:
             if key is None:
                 return None
 
-            _keys.append(f"{key}_k")
+            if self.is_mla:
+                _keys.append(f"{key}_k")
+            else:
+                _keys.append(f"{key}_{local_rank}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result
 
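Note on the Mooncake key scheme: the TP rank moves out of the content hash and into the storage key. Page hashes now depend only on the token ids (chained through prior_hash), so MLA deployments, whose KV entries are shared, can hit the same record from every rank, while non-MLA lookups append the rank when the key is built (f"{key}_{local_rank}_k"). A minimal sketch of the chained hash, following the scheme above:

    import hashlib

    import numpy as np

    def page_hash(token_ids, prior_hash=None):
        # Chain pages: fold a digest of the previous page's hash into this one.
        prefix = hashlib.sha256(prior_hash.encode()).hexdigest() if prior_hash else ""
        digest = hashlib.sha256(np.array(token_ids).tobytes()).hexdigest()
        return f"{prefix}_{int(digest[:16], 16)}"

    h0 = page_hash([1, 2, 3])
    h1 = page_hash([4, 5, 6], prior_hash=h0)  # depends on the whole prefix
    assert h1 != page_hash([4, 5, 6])         # same page, different prefix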