sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch

@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
+    get_normalized_lora_weight_names,
     get_stacked_multiply,
     get_weight_name,
 )
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
+        max_lora_rank: int,
+        lora_weight_names: Tuple[Set[str], Set[str]],
+        base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
+        self.max_lora_rank: int = max_lora_rank
+
+        # lora weight names for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
         # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

+        self.init_buffers(base_model)
+
+    def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
+        """
+        Check if the memory pool can support the given LoRA adapters.
+        """
+
+        def _can_support(config: LoRAConfig) -> bool:
+            """
+            Check if the memory pool can support a single LoRA adapter.
+            """
+            if config.r > self.max_lora_rank:
+                return False
+            weights_a, weights_b = get_normalized_lora_weight_names(
+                config.target_modules
+            )
+            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
+                self.lora_weight_names[1]
+            )
+
+        if isinstance(config, LoRAConfig):
+            return _can_support(config)
+        else:
+            return all(_can_support(x) for x in config)
+
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
             max_lora_dim,
         )

-    def init_buffers(
-        self,
-        lora_weight_names: Tuple[Set[str]],
-        base_model: torch.nn.Module,
-        max_lora_dim: int,
-    ):
-        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
+    def init_buffers(self, base_model: torch.nn.Module):
         device = next(base_model.parameters()).device

-        def update_buffer(
+        def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             lora_weight_names: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            new_weight_names = lora_weight_names - buffer.keys()
-            for module_name in new_weight_names:
-                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+            for module_name in lora_weight_names:
+                lora_shape = get_lora_shape_fn(
+                    module_name, base_model, self.max_lora_rank
+                )
                 buffer[module_name] = [
                     torch.empty(
                         lora_shape,
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
                     for _ in range(self.num_layer)
                 ]

-        update_buffer(
+        init_buffer(
             self.A_buffer,
-            lora_weight_names[0],
+            self.lora_weight_names[0],
             self.get_lora_A_shape,
         )

-        update_buffer(
+        init_buffer(
             self.B_buffer,
-            lora_weight_names[1],
+            self.lora_weight_names[1],
             self.get_lora_B_shape,
         )

@@ -161,10 +188,18 @@ class LoRAMemoryPool:
         lora_adapter: LoRAAdapter,
         lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
-        def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
-            assert (
-                buffer_view.shape == weight.shape
-            ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+        def load_lora_weight_tensor(
+            buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
+        ):
+            if weight is None:
+                # If the particular weight is not present in the adapter, we initialize the buffer to zero
+                # to avoid contamination from the residual weight of the evicted adapters.
+                buffer_view.zero_()
+            else:
+                assert (
+                    buffer_view.shape == weight.shape
+                ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+                buffer_view.copy_(weight)

         if uid is None:
             for i in range(self.num_layer):
@@ -176,8 +211,12 @@ class LoRAMemoryPool:
         lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
-            temp_A_buffer: Dict[str, torch.Tensor] = {}
-            temp_B_buffer: Dict[str, torch.Tensor] = {}
+            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.A_buffer
+            }
+            temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.B_buffer
+            }
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
@@ -193,6 +232,14 @@ class LoRAMemoryPool:
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+
+                    if temp_A_buffer[weight_name] is None:
+                        # Skip weight slicing if the weight is not present in the adapter
+                        continue
+
                     if "qkv_proj" in module_name:
                         temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                             temp_A_buffer["qkv_proj"], self.tp_rank
@@ -204,9 +251,10 @@ class LoRAMemoryPool:
                             )
                         )
                     else:
-                        weight_name = get_weight_name(
-                            module_name, self.lora_weight_names, LoRAType.LORA_A
-                        )
+                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
+                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
+                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
+                        # FlashInfer LoRA backend.
                         temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                             temp_A_buffer[weight_name], self.tp_rank
                         )
@@ -219,8 +267,7 @@ class LoRAMemoryPool:
                 buffer_view = self.A_buffer[name][layer_id][buffer_id][
                     : lora_rank * c, :
                 ]
-                check_lora_weight_shape(buffer_view, weights)
-                buffer_view.copy_(weights)
+                load_lora_weight_tensor(buffer_view, weights)

             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
@@ -229,14 +276,15 @@ class LoRAMemoryPool:
                         buffer_view = self.B_buffer[name][layer_id][stacked_id][
                             buffer_id
                         ][:, :lora_rank]
-                        check_lora_weight_shape(buffer_view, weights[stacked_id])
-                        buffer_view.copy_(weights[stacked_id])
+                        weight_slice = (
+                            weights[stacked_id] if weights is not None else None
+                        )
+                        load_lora_weight_tensor(buffer_view, weight_slice)
                 else:
                     buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
                         :, :lora_rank
                     ]
-                    check_lora_weight_shape(buffer_view, weights)
-                    buffer_view.copy_(weights)
+                    load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
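
Note on the hunk above: the new `can_support` method gates adapter loading on two conditions: the adapter rank `r` must fit the pool's preallocated `max_lora_rank`, and the adapter's weight names must be subsets of the pool's buffer names for both LoRA A and B. A minimal standalone sketch of that check follows; `DemoLoRAConfig` and the sample module names are hypothetical stand-ins, and the real check first normalizes names via `get_normalized_lora_weight_names`:

    from dataclasses import dataclass
    from typing import List, Set, Tuple

    @dataclass
    class DemoLoRAConfig:  # hypothetical stand-in for sglang's LoRAConfig
        r: int  # adapter rank
        target_modules: List[str]

    def can_support(
        config: DemoLoRAConfig,
        max_lora_rank: int,
        pool_weight_names: Tuple[Set[str], Set[str]],  # (LoRA A names, LoRA B names)
    ) -> bool:
        # Gate 1: the adapter rank must fit the preallocated buffers.
        if config.r > max_lora_rank:
            return False
        # Gate 2: every requested module must already have a buffer.
        # (Simplified: name normalization such as qkv_proj stacking is omitted.)
        names = set(config.target_modules)
        return names.issubset(pool_weight_names[0]) and names.issubset(
            pool_weight_names[1]
        )

    pool_names = ({"qkv_proj", "o_proj"}, {"qkv_proj", "o_proj"})
    print(can_support(DemoLoRAConfig(16, ["o_proj"]), 64, pool_names))   # True
    print(can_support(DemoLoRAConfig(128, ["o_proj"]), 64, pool_names))  # False: rank too large
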
sglang/srt/lora/utils.py CHANGED
@@ -1,7 +1,7 @@
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch

@@ -106,9 +106,11 @@ def get_hidden_dim(
         raise NotImplementedError()


-def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
+def get_normalized_lora_weight_names(
+    target_modules: Iterable[str],
+) -> Tuple[set[str], set[str]]:
     """
-    Mapping a target module name to names of the normalized LoRA weights.
+    Mapping a list of target module names to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
         "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
         "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
     }
-    stacked = params_mapping.get(name, ([name], [name]))
-    return stacked
+
+    result = (set(), set())
+    for name in target_modules:
+        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
+        result[0].update(lora_a)
+        result[1].update(lora_b)
+    return result


 def get_stacked_multiply(module_name: str) -> int:
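
Note on the hunk above: `get_normalized_lora_weight_names` now folds an iterable of target modules into two name sets (LoRA A, LoRA B) instead of mapping one name at a time. A runnable sketch restricted to the two `params_mapping` entries visible in this diff (the full mapping in the package contains more entries, not shown here):

    from typing import Iterable, Set, Tuple

    # Subset of the real params_mapping; only the entries shown in the diff above.
    PARAMS_MAPPING = {
        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
    }

    def normalized_weight_names(target_modules: Iterable[str]) -> Tuple[Set[str], Set[str]]:
        result: Tuple[Set[str], Set[str]] = (set(), set())
        for name in target_modules:
            # Unknown modules map to themselves on both the A and B side.
            lora_a, lora_b = PARAMS_MAPPING.get(name, ([name], [name]))
            result[0].update(lora_a)  # buffer names for LoRA A
            result[1].update(lora_b)  # buffer names for LoRA B
        return result

    # qkv_proj stacks into one A buffer but splits into q/kv B buffers.
    a_names, b_names = normalized_weight_names(["qkv_proj", "o_proj"])
    print(sorted(a_names))  # ['o_proj', 'qkv_proj']
    print(sorted(b_names))  # ['kv_proj', 'o_proj', 'q_proj']
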
sglang/srt/managers/cache_controller.py CHANGED
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)


@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()


+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:

     def __init__(
@@ -169,6 +222,8 @@ class HiCacheController:
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +241,19 @@ class HiCacheController:
         else:
             self.io_backend = io_backend

+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = prefetch_threshold
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +286,26 @@
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
+
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
@@ -232,6 +317,13 @@
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()

         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True
@@ -243,6 +335,16 @@
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -383,3 +485,142 @@
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue

+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # not to prefetch if not enough benefits
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                else:
+                    operation.hash_value = hash_value
+                    logger.debug(
+                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+
+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    # todo, handle failures in storage backend
+                    self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                self.ack_backup_queue.put((operation.id, operation.hash_value))
+
+            except Empty:
+                continue
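
Note on the hunks above: both the prefetch and backup paths key storage pages by a rolling hash chain: each `page_size` chunk of token ids is hashed together with the previous page's hash, so a page's key encodes its full prefix. A self-contained sketch of that chaining; the `page_hash` helper below is an illustrative stand-in, since the real key derivation lives in `get_hash_str` in `hicache_storage.py` and is not shown in this diff:

    import hashlib
    from typing import List, Optional

    def page_hash(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
        # Illustrative stand-in for get_hash_str: hash the page's tokens plus
        # the previous page's hash so the key is prefix-dependent.
        h = hashlib.sha256()
        if prior_hash:
            h.update(bytes.fromhex(prior_hash))
        for t in token_ids:
            h.update(t.to_bytes(4, "little", signed=False))
        return h.hexdigest()

    def chain_page_hashes(token_ids: List[int], page_size: int) -> List[str]:
        # Mirrors the loop in backup_thread_func: one hash per full page, chained.
        hashes, last = [], None
        for i in range(0, len(token_ids) - len(token_ids) % page_size, page_size):
            last = page_hash(token_ids[i : i + page_size], last)
            hashes.append(last)
        return hashes

    keys = chain_page_hashes(list(range(150)), page_size=64)
    print(len(keys))  # 2 full pages; the trailing 22 tokens are not paged out

On the prefetch side, the same chain is walked forward with `exists()` checks, and the whole operation is revoked via `prefetch_revoke_queue` if fewer than `prefetch_threshold` tokens hit in storage.
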
sglang/srt/managers/io_struct.py CHANGED
@@ -13,14 +13,14 @@
 # ==============================================================================
 """
 The definition of objects transferred between different
-processes (TokenizerManager, DetokenizerManager, Controller).
+processes (TokenizerManager, DetokenizerManager, Scheduler).
 """

 import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
@@ -42,8 +42,21 @@ class SessionParams:
     drop_previous_output: Optional[bool] = None


-AudioDataItem = Union[str, Dict]
-ImageDataItem = Union[Image, str, Dict]
+# Type definitions for multimodal input data
+# Individual data item types for each modality
+ImageDataInputItem = Union[Image, str, Dict]
+AudioDataInputItem = Union[str, Dict]
+VideoDataInputItem = Union[str, Dict]
+# Union type for any multimodal data item
+MultimodalDataInputItem = Union[
+    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
+]
+# Format types supporting single items, lists, or nested lists for batch processing
+MultimodalDataInputFormat = Union[
+    List[List[MultimodalDataInputItem]],
+    List[MultimodalDataInputItem],
+    MultimodalDataInputItem,
+]


 @dataclass
@@ -60,13 +73,11 @@ class GenerateReqInput:
     # - List of images (one per request in a batch)
     # - List of lists of images (multiple images per request)
     # See also python/sglang/srt/utils.py:load_image for more details.
-    image_data: Optional[
-        Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
-    ] = None
-    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
-    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    image_data: Optional[MultimodalDataInputFormat] = None
     # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
-    video_data: Optional[Union[List[List[str]], List[str], str]] = None
+    video_data: Optional[MultimodalDataInputFormat] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[MultimodalDataInputFormat] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -297,6 +308,9 @@
                     self.modalities.append("image")
                 elif len(self.image_data[i]) > 1:
                     self.modalities.append("multi-images")
+                else:
+                    # Ensure len(self.modalities) == len(self.image_data)
+                    self.modalities.append(None)
             # Expand parallel_sample_num
             self.image_data = self.image_data * self.parallel_sample_num
             self.modalities = self.modalities * self.parallel_sample_num
@@ -521,19 +535,17 @@ class EmbeddingReqInput:
     # - List of images (one per request in a batch)
     # - List of lists of images (multiple images per request)
     # See also python/sglang/srt/utils.py:load_image for more details.
-    image_data: Optional[
-        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
-    ] = None
+    image_data: Optional[MultimodalDataInputFormat] = None
     # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
-    video_data: Optional[Union[List[str], str]] = None
+    video_data: Optional[MultimodalDataInputFormat] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
-    audio_data: Optional[Union[List[str], str]] = None
+    audio_data: Optional[MultimodalDataInputFormat] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
@@ -607,8 +619,6 @@
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
                 text=[self.text[i]] if self.text is not None else None,
-                input_ids=None,
-                image_data=None,
                 sampling_params=self.sampling_params[i],
                 rid=self.rid[i],
                 is_cross_encoder_request=True,
@@ -618,6 +628,8 @@
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
             image_data=self.image_data[i] if self.image_data is not None else None,
+            audio_data=self.audio_data[i] if self.audio_data is not None else None,
+            video_data=self.video_data[i] if self.video_data is not None else None,
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
         )
@@ -941,17 +953,6 @@ class ProfileReqType(Enum):
     STOP_PROFILE = 2


-class ExpertDistributionReq(Enum):
-    START_RECORD = 1
-    STOP_RECORD = 2
-    DUMP_RECORD = 3
-
-
-@dataclass
-class ExpertDistributionReqOutput:
-    pass
-
-
 @dataclass
 class ProfileReq:
     type: ProfileReqType
@@ -1001,6 +1002,17 @@ class HealthCheckOutput:
     pass


+class ExpertDistributionReq(Enum):
+    START_RECORD = 1
+    STOP_RECORD = 2
+    DUMP_RECORD = 3
+
+
+@dataclass
+class ExpertDistributionReqOutput:
+    pass
+
+
 @dataclass
 class Function:
     description: Optional[str] = None
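
Note on the hunks above: the new `MultimodalDataInputFormat` alias accepts a single item, a flat list (one item per request in a batch), or a nested list (multiple items per request), uniformly for image, video, and audio inputs. Under those type aliases, all three of the following are valid `image_data` values for `GenerateReqInput` (file names are placeholders):

    # A single image for one request:
    image_data = "https://example.com/cat.png"

    # One image per request in a batch of two:
    image_data = ["img_a.png", "img_b.png"]

    # Multiple images per request (batch of two, two images each):
    image_data = [
        ["img_a1.png", "img_a2.png"],
        ["img_b1.png", "img_b2.png"],
    ]
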