sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,158 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import torch
6
+
7
+ from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
8
+
9
+ if TYPE_CHECKING:
10
+ from sglang.srt.mem_cache.memory_pool import KVCache
11
+
12
+
13
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Compute KV-cache slot indices for the extend tokens of a batch.

    For every request ``i`` the freshly extended tokens
    (``seq_lens[i] - prefix_lens[i]`` of them) are mapped to slots in three
    pieces:

      1. the unused tail of the request's last (partially filled) prefix page,
      2. as many whole pages as fit, taken from ``free_pages``,
      3. a final partial page for the leftover tokens.

    Results are written into ``out_indices`` in batch order; pages are
    consumed from ``free_pages`` front-to-back (the caller is expected to
    drop the consumed prefix of ``free_pages`` afterwards).

    Args:
        prefix_lens: per-request cached-prefix lengths (1-D int tensor).
        seq_lens: per-request total lengths after the extend (1-D int tensor).
        last_loc: per-request slot index of the last prefix token; the next
            token goes to ``last_loc + 1`` (assumes page-aligned continuation).
        free_pages: 1-D tensor of free page ids; slot = page_id * page_size + offset.
        out_indices: preallocated 1-D int32 output of total extend length.
        page_size: tokens per page.
        device: device for the scratch offset tensor.
    """
    added = seq_lens - prefix_lens
    write_end = torch.cumsum(added, 0)
    write_start = write_end - added

    # Pages newly required per request: ceil(seq/ps) - ceil(prefix/ps).
    pages_total = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    # Of those, pages that end up completely filled.
    pages_full = (seq_lens) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    partial_tail = pages_total - pages_full  # 0 or 1 per request
    page_end = torch.cumsum(pages_total, 0)
    page_start = page_end - pages_total
    offsets = torch.arange(page_size, device=device, dtype=torch.int32)

    for i in range(len(prefix_lens)):
        # Piece 1: finish the partially used last prefix page.
        head = (
            min(
                seq_lens[i],
                (prefix_lens[i] + page_size - 1) // page_size * page_size,
            )
            - prefix_lens[i]
        )
        if head:
            out_indices[write_start[i] : write_start[i] + head] = (
                last_loc[i] + 1 + offsets[:head].view(-1)
            )

        # Piece 2: whole new pages (excludes a trailing partial page, if any).
        body = (
            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
        ) * page_size
        if body:
            page_bases = (
                free_pages[page_start[i] : page_end[i] - partial_tail[i]]
                * page_size
            )
            out_indices[write_start[i] + head : write_start[i] + head + body] = (
                page_bases.view(-1, 1) + offsets.view(1, -1)
            ).view(-1)

        # Piece 3: leftover tokens in the request's last (new, partial) page.
        tail = seq_lens[i] - seq_lens[i] // page_size * page_size
        if tail:
            out_indices[write_end[i] - tail : write_end[i]] = (
                free_pages[page_end[i] - 1] * page_size + offsets[:tail]
            ).view(-1)
65
+
66
+
67
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged token-to-KV allocator for Ascend NPUs.

    Reuses the generic paged allocator's bookkeeping but computes the
    extend-phase slot indices with ``alloc_extend_kernel_ascend`` on the host
    instead of a Triton kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        # The trailing 1 is forwarded verbatim to the parent constructor
        # (presumably a backend-specific knob of PagedTokenToKVPoolAllocator
        # — see its signature for the exact meaning).
        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for extend tokens; returns int32 indices or None."""
        if self.debug_mode:
            # The next write position must continue the prefix within a page.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Total new pages needed across the batch: ceil(seq/ps) - ceil(prefix/ps).
        pages_needed = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        if self.need_sort and pages_needed > len(self.free_pages):
            self.merge_and_sort_free()

        if pages_needed > len(self.free_pages):
            # Out of memory even after reclaiming released pages.
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            # Every handed-out slot must be unique.
            assert len(torch.unique(out_indices)) == len(out_indices)

        # The kernel consumed pages from the front of the free list.
        self.free_pages = self.free_pages[pages_needed:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for decode; returns int32 indices or None."""
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        # A request crosses into a fresh page exactly when its new length
        # lands on position 1 of a page.
        crosses_page = (seq_lens % self.page_size == 1).int()
        pages_needed = crosses_page.sum().item()

        if pages_needed > len(self.free_pages):
            self.merge_and_sort_free()

        if pages_needed > len(self.free_pages):
            return None

        page_end = torch.cumsum(crosses_page, 0)
        page_start = page_end - crosses_page
        if pages_needed == 0:
            out_indices = last_loc + 1
        else:
            # Per request: either continue in-page (last_loc + 1) or take the
            # first slot of its assigned fresh page; the masks are disjoint.
            out_indices = (last_loc + 1) * (1 - crosses_page) + self.free_pages[
                page_start
            ] * self.page_size * crosses_page

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[pages_needed:]
        return out_indices.int()
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  """Cache for chunked prefill, used when RadixCache is disabled."""
4
4
 
5
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
5
+ from typing import TYPE_CHECKING, Any, Optional
6
6
 
7
7
  import torch
8
8
 
@@ -15,7 +15,7 @@ from sglang.srt.distributed import (
15
15
  )
16
16
 
17
17
 
18
- def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
18
+ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
19
19
  hasher = hashlib.sha256()
20
20
 
21
21
  if prior_hash:
@@ -71,8 +71,10 @@ class HiRadixCache(RadixCache):
71
71
  self.tp_group = tp_cache_group
72
72
  self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
73
73
  self.enable_storage = hicache_storage_backend is not None
74
- # todo: customizable storage prefetch threshold
74
+ # todo: customizable storage prefetch threshold and timeout
75
75
  self.prefetch_threshold = 256
76
+ self.prefetch_timeout = 3 # seconds
77
+ self.prefetch_stop_policy = hicache_storage_prefetch_policy
76
78
 
77
79
  self.load_cache_event = threading.Event()
78
80
  self.cache_controller = HiCacheController(
@@ -87,13 +89,6 @@ class HiRadixCache(RadixCache):
87
89
  prefetch_threshold=self.prefetch_threshold,
88
90
  )
89
91
 
90
- self.prefetch_stop_policy = hicache_storage_prefetch_policy
91
- # todo: customizable storage prefetch timeout
92
- self.prefetch_timeout = 3 # seconds
93
- logger.info(
94
- f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
95
- )
96
-
97
92
  # record the nodes with ongoing write through
98
93
  self.ongoing_write_through = {}
99
94
  # record the node segments with ongoing load back
@@ -151,7 +146,7 @@ class HiRadixCache(RadixCache):
151
146
 
152
147
  def write_backup_storage(self, node: TreeNode):
153
148
  operation_id = self.cache_controller.write_storage(
154
- node.host_value, node.key, node.parent.get_last_hash_value()
149
+ node.host_value, node.key, node.hash_value
155
150
  )
156
151
  self.ongoing_backup[operation_id] = node
157
152
  node.protect_host()
@@ -414,18 +409,18 @@ class HiRadixCache(RadixCache):
414
409
  group=self.tp_group,
415
410
  )
416
411
  for _ in range(queue_size.item()):
417
- ack_id, hash_value, completed_tokens = (
418
- self.cache_controller.ack_backup_queue.get()
419
- )
412
+ ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
420
413
  host_node = self.ongoing_backup[ack_id]
421
- if completed_tokens == 0:
422
- host_node.hash_value = None
423
- elif completed_tokens < len(host_node.key):
424
- # backup is only partially successful, split the node
425
- new_node = self._split_node(host_node.key, host_node, completed_tokens)
426
- new_node.hash_value = hash_value
427
- else:
428
- host_node.hash_value = hash_value
414
+
415
+ if completed_tokens > 0:
416
+ if completed_tokens < len(host_node.key):
417
+ # backup is only partially successful, split the node
418
+ new_node = self._split_node(
419
+ host_node.key, host_node, completed_tokens
420
+ )
421
+ new_node.backuped_storage = True
422
+ else:
423
+ host_node.backuped_storage = True
429
424
  host_node.release_host()
430
425
  del self.ongoing_backup[ack_id]
431
426
 
@@ -471,6 +466,10 @@ class HiRadixCache(RadixCache):
471
466
  req_id
472
467
  ]
473
468
 
469
+ if operation.host_indices is None:
470
+ # prefetch has not been issued due to insufficient host memory
471
+ return True
472
+
474
473
  if not self.can_terminate_prefetch(operation):
475
474
  return False
476
475
 
@@ -565,10 +564,6 @@ class HiRadixCache(RadixCache):
565
564
  if host_indices is None:
566
565
  self.evict_host(prefetch_length)
567
566
  host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
568
- if host_indices is None:
569
- last_host_node.release_host()
570
- # no sufficient host memory to prefetch
571
- return
572
567
  operation = self.cache_controller.prefetch(
573
568
  req_id, host_indices, new_input_tokens, last_hash
574
569
  )
@@ -717,6 +712,21 @@ class HiRadixCache(RadixCache):
717
712
  node.children[child_key] = new_node
718
713
  self.evictable_size_ += len(value)
719
714
 
715
+ if self.enable_storage:
716
+ last_hash = node.get_last_hash_value()
717
+ assert (node == self.root_node) or (
718
+ last_hash is not None
719
+ ), "Parent node must have a hash value with storage enabled"
720
+ new_node.hash_value = []
721
+ for idx in range(0, len(key), self.page_size):
722
+ new_node.hash_value.append(
723
+ self.cache_controller.get_hash_str(
724
+ key[idx : idx + self.page_size],
725
+ prior_hash=last_hash,
726
+ )
727
+ )
728
+ last_hash = new_node.hash_value[-1]
729
+
720
730
  if self.cache_controller.write_policy != "write_back":
721
731
  self.inc_hit_count(new_node)
722
732
  return total_prefix_length
@@ -0,0 +1,421 @@
1
+ """Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
2
+
3
+ import heapq
4
+ import time
5
+ from collections import defaultdict
6
+ from typing import TYPE_CHECKING, Any, List, Optional
7
+
8
+ import torch
9
+
10
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
11
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
12
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
13
+
14
+ if TYPE_CHECKING:
15
+ from sglang.srt.managers.schedule_batch import Req
16
+ else:
17
+ Req = Any # Placeholder for Req type when not type checking
18
+
19
+
20
class LoRAKey:
    """Composite radix-tree key: a LoRA adaptor id plus a token-id sequence."""

    def __init__(self, lora_id: str, token_ids: List[int]):
        # Adaptor identifier; expected to be the hash of the adaptor path.
        self.lora_id = lora_id
        # Token ids covered by this key.
        self.token_ids = token_ids

    def __len__(self):
        # Key length counts tokens only; the lora id contributes nothing.
        return len(self.token_ids)
30
+
31
+
32
def get_child_key(key: LoRAKey):
    """Return the children-dict key for *key*.

    The child key is the hash of ``lora_id + str(token_ids[0])`` (just the
    first token when ``lora_id`` is None), so a child matches only when both
    the adaptor id and the leading token agree.
    """
    first_token = str(key.token_ids[0])
    if key.lora_id is None:
        return hash(first_token)
    return hash(key.lora_id + first_token)
39
+
40
+
41
class LoRATreeNode:
    """Node of the LoRA radix tree.

    Nodes order by ``last_access_time`` so that a heap of leaves pops the
    least-recently-used node first during eviction.
    """

    # Monotonic source for auto-assigned node ids.
    counter = 0

    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(LoRATreeNode)
        self.parent: LoRATreeNode = None
        self.key: LoRAKey = None
        self.value: Optional[torch.Tensor] = None
        # Number of active requests pinning this node (and its ancestors).
        self.lock_ref = 0
        self.last_access_time = time.monotonic()

        self.id = LoRATreeNode.counter if id is None else id
        LoRATreeNode.counter += 1

    @property
    def evicted(self):
        # A node with no KV indices attached has been evicted.
        return self.value is None

    def __lt__(self, other: "LoRATreeNode"):
        # LRU ordering for heapq-based eviction.
        return self.last_access_time < other.last_access_time
62
+
63
+
64
def _key_match(key0: LoRAKey, key1: LoRAKey):
    """Length of the common token-id prefix of two keys.

    Both keys must carry the same ``lora_id``; comparing keys across
    adaptors is a caller bug and raises ``ValueError``.
    """
    if key0.lora_id != key1.lora_id:
        raise ValueError(
            f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
        )
    matched = 0
    for a, b in zip(key0.token_ids, key1.token_ids):
        if a != b:
            break
        matched += 1
    return matched
75
+
76
+
77
class LoRARadixCache(BasePrefixCache):
    """RadixCache variant whose tree keys carry a LoRA adaptor id.

    Prefixes computed under different LoRA adaptors are never shared: a child
    edge matches only when both ``lora_id`` and the leading token agree (see
    ``get_child_key``). Only ``page_size == 1`` is supported.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
    ):
        # Guard: the page-aligned bookkeeping below is written for page_size 1.
        if page_size > 1:
            raise ValueError("LoRARadixCache currently only supports page_size = 1")

        if token_to_kv_pool_allocator is None:
            raise ValueError(
                "token_to_kv_pool_allocator is required to run LoraRadixCache"
            )

        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        # When disabled, the cache degenerates to free-on-finish behavior.
        self.disable = disable
        self.device = self.token_to_kv_pool_allocator.device

        self.key_match_fn = _key_match
        self.get_child_key_fn = get_child_key
        self.reset()

    def reset(self):
        """Drop the whole tree and zero the size counters."""
        self.root_node = LoRATreeNode()
        self.root_node.key = LoRAKey(lora_id="", token_ids=[])
        # NOTE(review): root value is None; _total_size_helper takes
        # len(root.value) unconditionally — verify total_size()/pretty_print()
        # are not called on an empty tree, or this raises TypeError.
        self.root_node.value = None
        self.evictable_size_ = 0
        self.protected_size_ = 0

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        # Deliberately unsupported: matching needs the lora id as well.
        raise ValueError(
            "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
        )

    def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
        """Find the matching prefix from the lora radix tree.
        Args:
            key: A LoRAKey to find a matching prefix.
        Returns:
            A tuple of a tensor of matching prefix token IDs and
            the last node that contains the prefix values. Note that
            this API can modify the internal state of the Radix tree.
            The last node create a new child if the prefix is shorter
            than the last node's value.
        """
        if self.disable or len(key) == 0:
            # Nothing to match: empty indices anchored at the root.
            return MatchResult(
                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                last_device_node=self.root_node,
                last_host_node=self.root_node,
            )

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def insert(self, key: LoRAKey, value=None):
        """Insert key into the tree; returns the length of the pre-existing prefix."""
        if self.disable:
            return 0

        if value is None:
            # Fallback: use the token ids themselves as values (copy, not alias).
            value = [x for x in key.token_ids]
        return self._insert_helper(self.root_node, key, value)

    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        if self.disable:
            # Cache off: release all KV slots and the request slot directly.
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        # Last output token has no KV entry yet, hence [:-1].
        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        # page_size == 1, so the full length is already "page aligned".
        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

        # Radix Cache takes one ref in memory pool
        lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
        new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
        # Slots now duplicated inside the tree are returned to the allocator.
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # Remove req slot release the cache lock
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: Req):
        """Cache request when it is unfinished."""
        if self.disable:
            return

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
        inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
        new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # The prefix indices could be updated, reuse it
        # NOTE(review): unpacks 4 values from match_prefix_with_lora_id's
        # MatchResult although only 3 fields are passed above — assumes
        # MatchResult declares a 4th defaulted field; verify against
        # base_prefix_cache.MatchResult.
        new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key)
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
        )

        # Move the request's pin from the old last node to the new one.
        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        """Dump the tree and its total token count to stdout (debugging aid)."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        return self._total_size_helper()

    def evict(self, num_tokens: int):
        """Evict at least num_tokens tokens from unlocked leaves, LRU-first."""
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x == self.root_node:
                break
            if x.lock_ref > 0:
                # Pinned by an in-flight request; skip.
                continue

            self.token_to_kv_pool_allocator.free(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            # The parent may have just become a leaf — make it evictable too.
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: LoRATreeNode):
        """Pin node and all its ancestors; returns the (negative) change in evictable size."""
        if self.disable:
            return 0

        delta = 0
        while node != self.root_node:
            if node.lock_ref == 0:
                # First pin: tokens move from evictable to protected.
                self.evictable_size_ -= len(node.value)
                self.protected_size_ += len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: LoRATreeNode):
        """Unpin node and all its ancestors; returns the (positive) change in evictable size."""
        if self.disable:
            return 0

        delta = 0
        while node != self.root_node:
            if node.lock_ref == 1:
                # Last pin released: tokens become evictable again.
                self.evictable_size_ += len(node.value)
                self.protected_size_ -= len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        return self.evictable_size_

    def protected_size(self):
        # protected size refers to the size of the cache that is locked
        return self.protected_size_

    def all_values_flatten(self):
        """Concatenate the values of every node (pre-order, excluding the root)."""
        values = []

        def _dfs_helper(node: LoRATreeNode):
            for _, child in node.children.items():
                values.append(child.value)
                _dfs_helper(child)

        _dfs_helper(self.root_node)
        return torch.cat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
        """Walk down from node collecting matched values; may split the last node."""
        node.last_access_time = time.monotonic()

        child_key = self.get_child_key_fn(key)

        value = []
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                # Partial match: split so the matched half becomes its own node.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                node = new_node
                break
            else:
                # Full edge match: consume it and continue below.
                value.append(child.value)
                node = child
                key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
        # new_node -> child
        new_node = LoRATreeNode()
        key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
        key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
        new_node.children = {self.get_child_key_fn(key_split_2): child}
        new_node.parent = child.parent
        # The split node inherits the child's pin count unchanged.
        new_node.lock_ref = child.lock_ref
        new_node.key = key_split_1
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = key_split_2
        child.value = child.value[split_len:]
        # Re-point the parent's edge (same child key: first token unchanged).
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        return new_node

    def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
        """Insert (key, value) below node; returns length of the already-present prefix."""
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
            value = value[prefix_len:]

            if prefix_len < len(node.key):
                # Diverged mid-edge: split and hang the remainder off the split point.
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            # Unmatched suffix becomes a fresh leaf.
            new_node = LoRATreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)
        return total_prefix_length

    def _print_helper(self, node: LoRATreeNode, indent: int):
        """Prints the radix tree in a human-readable format."""
        stack = [(node, indent)]
        while stack:
            current_node, current_indent = stack.pop()
            print(
                " " * current_indent,
                len(current_node.key),
                current_node.key.token_ids[:10],
                f"r={current_node.lock_ref}",
            )
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                # Sanity check: each edge key must match its child's key.
                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _delete_leaf(self, node):
        # Find and remove the parent's edge pointing at this leaf.
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        # NOTE(review): decrements by len(node.key) while inserts increment by
        # len(value) — relies on key and value having equal length; verify.
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self):
        """Sum of value lengths over all non-evicted nodes (iterative DFS)."""
        total_size = 0
        stack = [self.root_node]
        while stack:
            current_node = stack.pop()
            total_size += len(current_node.value)
            for child in current_node.children.values():
                if child.evicted:
                    continue
                stack.append(child)
        return total_size

    def _collect_leaves(self):
        """Return every leaf node (nodes with no children), root included if empty."""
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            else:
                stack.extend(cur_node.children.values())

        return ret_list