sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
+    get_normalized_lora_weight_names,
     get_stacked_multiply,
     get_weight_name,
 )
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
+        max_lora_rank: int,
+        lora_weight_names: Tuple[Set[str], Set[str]],
+        base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
+        self.max_lora_rank: int = max_lora_rank
+
+        # lora weight names for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
         # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
+        self.init_buffers(base_model)
+
+    def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
+        """
+        Check if the memory pool can support the given LoRA adapters.
+        """
+
+        def _can_support(config: LoRAConfig) -> bool:
+            """
+            Check if the memory pool can support a single LoRA adapter.
+            """
+            if config.r > self.max_lora_rank:
+                return False
+            weights_a, weights_b = get_normalized_lora_weight_names(
+                config.target_modules
+            )
+            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
+                self.lora_weight_names[1]
+            )
+
+        if isinstance(config, LoRAConfig):
+            return _can_support(config)
+        else:
+            return all(_can_support(x) for x in config)
+
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
             max_lora_dim,
         )
 
-    def init_buffers(
-        self,
-        lora_weight_names: Tuple[Set[str]],
-        base_model: torch.nn.Module,
-        max_lora_dim: int,
-    ):
-        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
+    def init_buffers(self, base_model: torch.nn.Module):
         device = next(base_model.parameters()).device
 
-        def update_buffer(
+        def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             lora_weight_names: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            new_weight_names = lora_weight_names - buffer.keys()
-            for module_name in new_weight_names:
-                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+            for module_name in lora_weight_names:
+                lora_shape = get_lora_shape_fn(
+                    module_name, base_model, self.max_lora_rank
+                )
                 buffer[module_name] = [
                     torch.empty(
                         lora_shape,
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
                     for _ in range(self.num_layer)
                 ]
 
-        update_buffer(
+        init_buffer(
             self.A_buffer,
-            lora_weight_names[0],
+            self.lora_weight_names[0],
             self.get_lora_A_shape,
         )
 
-        update_buffer(
+        init_buffer(
             self.B_buffer,
-            lora_weight_names[1],
+            self.lora_weight_names[1],
             self.get_lora_B_shape,
         )
 
@@ -126,7 +153,7 @@ class LoRAMemoryPool:
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def get_available_buffer_slot():
            for buffer_id in range(self.max_loras_per_batch):
@@ -159,12 +186,20 @@ class LoRAMemoryPool:
         uid: str,
         buffer_id: int,
         lora_adapter: LoRAAdapter,
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
-        def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
-            assert (
-                buffer_view.shape == weight.shape
-            ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+        def load_lora_weight_tensor(
+            buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
+        ):
+            if weight is None:
+                # If the particular weight is not present in the adapter, we initialize the buffer to zero
+                # to avoid contamination from the residual weight of the evicted adapters.
+                buffer_view.zero_()
+            else:
+                assert (
+                    buffer_view.shape == weight.shape
+                ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+                buffer_view.copy_(weight)
 
         if uid is None:
             for i in range(self.num_layer):
@@ -176,8 +211,12 @@ class LoRAMemoryPool:
         lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
-            temp_A_buffer: Dict[str, torch.Tensor] = {}
-            temp_B_buffer: Dict[str, torch.Tensor] = {}
+            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.A_buffer
+            }
+            temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.B_buffer
+            }
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
@@ -193,6 +232,14 @@ class LoRAMemoryPool:
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+
+                    if temp_A_buffer[weight_name] is None:
+                        # Skip weight slicing if the weight is not present in the adapter
+                        continue
+
                     if "qkv_proj" in module_name:
                         temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                             temp_A_buffer["qkv_proj"], self.tp_rank
@@ -204,9 +251,10 @@ class LoRAMemoryPool:
                             )
                         )
                     else:
-                        weight_name = get_weight_name(
-                            module_name, self.lora_weight_names, LoRAType.LORA_A
-                        )
+                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
+                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
+                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
+                        # FlashInfer LoRA backend.
                         temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                             temp_A_buffer[weight_name], self.tp_rank
                         )
@@ -219,8 +267,7 @@ class LoRAMemoryPool:
                 buffer_view = self.A_buffer[name][layer_id][buffer_id][
                     : lora_rank * c, :
                 ]
-                check_lora_weight_shape(buffer_view, weights)
-                buffer_view.copy_(weights)
+                load_lora_weight_tensor(buffer_view, weights)
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
@@ -229,14 +276,15 @@ class LoRAMemoryPool:
                         buffer_view = self.B_buffer[name][layer_id][stacked_id][
                             buffer_id
                         ][:, :lora_rank]
-                        check_lora_weight_shape(buffer_view, weights[stacked_id])
-                        buffer_view.copy_(weights[stacked_id])
+                        weight_slice = (
+                            weights[stacked_id] if weights is not None else None
+                        )
+                        load_lora_weight_tensor(buffer_view, weight_slice)
                 else:
                     buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
                         :, :lora_rank
                     ]
-                    check_lora_weight_shape(buffer_view, weights)
-                    buffer_view.copy_(weights)
+                    load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
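
For readers skimming this hunk: the new `can_support` gate admits an adapter only when its rank fits the pre-allocated `max_lora_rank` and its normalized target-module names fall inside the name sets the buffers were built from. A minimal standalone sketch of that predicate, under the assumption that only rank and weight-name membership matter (the `PoolSpec`/`AdapterSpec` names are illustrative stand-ins, not sglang classes):

```python
# Hypothetical stand-ins for LoRAMemoryPool state and LoRAConfig, for illustration only.
from dataclasses import dataclass
from typing import Iterable, Set, Tuple


@dataclass
class AdapterSpec:
    r: int  # LoRA rank of the adapter
    weight_names: Tuple[Set[str], Set[str]]  # normalized names for (LoRA A, LoRA B)


@dataclass
class PoolSpec:
    max_lora_rank: int
    lora_weight_names: Tuple[Set[str], Set[str]]

    def can_support(self, adapters: Iterable[AdapterSpec]) -> bool:
        # An adapter fits if its rank does not exceed the pre-allocated rank and
        # every weight it carries already has a buffer slot in the pool.
        return all(
            a.r <= self.max_lora_rank
            and a.weight_names[0].issubset(self.lora_weight_names[0])
            and a.weight_names[1].issubset(self.lora_weight_names[1])
            for a in adapters
        )


pool = PoolSpec(16, ({"qkv_proj"}, {"q_proj", "kv_proj"}))
print(pool.can_support([AdapterSpec(8, ({"qkv_proj"}, {"q_proj", "kv_proj"}))]))   # True
print(pool.can_support([AdapterSpec(32, ({"qkv_proj"}, {"q_proj", "kv_proj"}))]))  # False: rank too large
```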
sglang/srt/lora/utils.py CHANGED
@@ -1,7 +1,7 @@
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 
@@ -106,9 +106,11 @@ def get_hidden_dim(
         raise NotImplementedError()
 
 
-def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
+def get_normalized_lora_weight_names(
+    target_modules: Iterable[str],
+) -> Tuple[set[str], set[str]]:
     """
-    Mapping a target module name to names of the normalized LoRA weights.
+    Mapping a list of target module name to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
         "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
         "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
     }
-    stacked = params_mapping.get(name, ([name], [name]))
-    return stacked
+
+    result = (set(), set())
+    for name in target_modules:
+        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
+        result[0].update(lora_a)
+        result[1].update(lora_b)
+    return result
 
 
 def get_stacked_multiply(module_name: str) -> int:
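
As a quick illustration of the new signature: `get_normalized_lora_weight_names` now takes an iterable of target-module names and aggregates them into one pair of sets (LoRA A names, LoRA B names), with unmapped names passing through unchanged. A self-contained sketch of the same aggregation, deliberately limited to the two mapping entries visible in the hunk above (the full table in sglang/srt/lora/utils.py has more entries):

```python
# Standalone sketch of the aggregation loop shown above; params_mapping is truncated
# to the entries visible in this hunk, so it is not the full sglang mapping.
from typing import Iterable, Set, Tuple

params_mapping = {
    "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
    "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}


def normalize(target_modules: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    result: Tuple[Set[str], Set[str]] = (set(), set())
    for name in target_modules:
        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
        result[0].update(lora_a)
        result[1].update(lora_b)
    return result


print(normalize(["qkv_proj", "down_proj"]))
# -> ({'qkv_proj', 'down_proj'}, {'q_proj', 'kv_proj', 'down_proj'})  (set order may vary)
```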
sglang/srt/managers/cache_controller.py CHANGED
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)
 
 
@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()
 
 
+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:
 
     def __init__(
@@ -166,9 +219,12 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +242,25 @@ class HiCacheController:
         else:
             self.io_backend = io_backend
 
+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +293,26 @@ class HiCacheController:
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
+
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
@@ -232,6 +324,13 @@ class HiCacheController:
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()
 
         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True
@@ -243,6 +342,16 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -256,6 +365,7 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
+        torch.cuda.current_stream().synchronize()
         self.write_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -383,3 +493,181 @@ class HiCacheController:
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
+
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # not to prefetch if not enough benefits
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                    logger.debug(
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
+                else:
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
+                    logger.debug(
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+
+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    success = self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )
+
+            except Empty:
+                continue
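
Stepping back from the hunk: the storage path hashes the token stream one `page_size` page at a time, feeding each page's hash into the next `get_hash_str` call, so a page's key encodes its entire prefix. Prefetch walks this chain until the first miss, backup writes pages in the same order, and under multi-rank TP both paths take an all-reduce MIN over per-rank progress so every rank agrees on how many pages were actually hit or written. A minimal sketch of the chained-hash walk, with `chained_hash` and the in-memory `store` as illustrative stand-ins rather than sglang's `get_hash_str` or a real backend:

```python
# Illustrative sketch of prefix-chained page hashing and the hit-count walk;
# chained_hash and the in-memory "store" are stand-ins, not sglang APIs.
import hashlib
from typing import Dict, List, Optional

PAGE_SIZE = 4


def chained_hash(page_tokens: List[int], prior_hash: Optional[str]) -> str:
    # Each page hash folds in the previous page's hash, so equal keys imply equal prefixes.
    h = hashlib.sha256()
    if prior_hash is not None:
        h.update(prior_hash.encode())
    h.update(",".join(map(str, page_tokens)).encode())
    return h.hexdigest()


def count_hit_tokens(tokens: List[int], store: Dict[str, bytes]) -> int:
    """Walk page by page until the first miss, mirroring the hit-counting loop above."""
    hit_tokens, last_hash = 0, None
    while len(tokens) - hit_tokens >= PAGE_SIZE:
        last_hash = chained_hash(tokens[hit_tokens : hit_tokens + PAGE_SIZE], last_hash)
        if last_hash not in store:
            break
        hit_tokens += PAGE_SIZE
    return hit_tokens


# Back up two pages, then probe a longer prompt that shares the same prefix.
store: Dict[str, bytes] = {}
prompt = list(range(8))
h = None
for i in range(0, len(prompt), PAGE_SIZE):
    h = chained_hash(prompt[i : i + PAGE_SIZE], h)
    store[h] = b"kv-page"
print(count_hit_tokens(list(range(12)), store))  # 8: first two pages hit, third misses
```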