sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch

@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
+    get_normalized_lora_weight_names,
     get_stacked_multiply,
     get_weight_name,
 )
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
+        max_lora_rank: int,
+        lora_weight_names: Tuple[Set[str], Set[str]],
+        base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
+        self.max_lora_rank: int = max_lora_rank
+
+        # lora weight names for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
         # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

+        self.init_buffers(base_model)
+
+    def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
+        """
+        Check if the memory pool can support the given LoRA adapters.
+        """
+
+        def _can_support(config: LoRAConfig) -> bool:
+            """
+            Check if the memory pool can support a single LoRA adapter.
+            """
+            if config.r > self.max_lora_rank:
+                return False
+            weights_a, weights_b = get_normalized_lora_weight_names(
+                config.target_modules
+            )
+            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
+                self.lora_weight_names[1]
+            )
+
+        if isinstance(config, LoRAConfig):
+            return _can_support(config)
+        else:
+            return all(_can_support(x) for x in config)
+
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
             max_lora_dim,
         )

-    def init_buffers(
-        self,
-        lora_weight_names: Tuple[Set[str]],
-        base_model: torch.nn.Module,
-        max_lora_dim: int,
-    ):
-        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
+    def init_buffers(self, base_model: torch.nn.Module):
         device = next(base_model.parameters()).device

-        def
+        def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             lora_weight_names: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-
-
-
+            for module_name in lora_weight_names:
+                lora_shape = get_lora_shape_fn(
+                    module_name, base_model, self.max_lora_rank
+                )
                 buffer[module_name] = [
                     torch.empty(
                         lora_shape,
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
                     for _ in range(self.num_layer)
                 ]

-
+        init_buffer(
             self.A_buffer,
-            lora_weight_names[0],
+            self.lora_weight_names[0],
             self.get_lora_A_shape,
         )

-
+        init_buffer(
             self.B_buffer,
-            lora_weight_names[1],
+            self.lora_weight_names[1],
             self.get_lora_B_shape,
         )

@@ -161,10 +188,18 @@ class LoRAMemoryPool:
         lora_adapter: LoRAAdapter,
         lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
-        def
-
-
-
+        def load_lora_weight_tensor(
+            buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
+        ):
+            if weight is None:
+                # If the particular weight is not present in the adapter, we initialize the buffer to zero
+                # to avoid contamination from the residual weight of the evicted adapters.
+                buffer_view.zero_()
+            else:
+                assert (
+                    buffer_view.shape == weight.shape
+                ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+                buffer_view.copy_(weight)

         if uid is None:
             for i in range(self.num_layer):
@@ -176,8 +211,12 @@ class LoRAMemoryPool:
         lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
-            temp_A_buffer: Dict[str, torch.Tensor] = {
-
+            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.A_buffer
+            }
+            temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.B_buffer
+            }
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
@@ -193,6 +232,14 @@ class LoRAMemoryPool:
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+
+                    if temp_A_buffer[weight_name] is None:
+                        # Skip weight slicing if the weight is not present in the adapter
+                        continue
+
                     if "qkv_proj" in module_name:
                         temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                             temp_A_buffer["qkv_proj"], self.tp_rank
@@ -204,9 +251,10 @@ class LoRAMemoryPool:
                             )
                         )
                     else:
-
-
-
+                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
+                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
+                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
+                        # FlashInfer LoRA backend.
                         temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                             temp_A_buffer[weight_name], self.tp_rank
                         )
@@ -219,8 +267,7 @@ class LoRAMemoryPool:
                 buffer_view = self.A_buffer[name][layer_id][buffer_id][
                     : lora_rank * c, :
                 ]
-
-                buffer_view.copy_(weights)
+                load_lora_weight_tensor(buffer_view, weights)

             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
@@ -229,14 +276,15 @@ class LoRAMemoryPool:
                         buffer_view = self.B_buffer[name][layer_id][stacked_id][
                             buffer_id
                         ][:, :lora_rank]
-
-
+                        weight_slice = (
+                            weights[stacked_id] if weights is not None else None
+                        )
+                        load_lora_weight_tensor(buffer_view, weight_slice)
                 else:
                     buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
                         :, :lora_rank
                     ]
-
-                    buffer_view.copy_(weights)
+                    load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
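The heart of this change is the new `load_lora_weight_tensor` helper: when a freshly loaded adapter does not provide a given weight, the buffer slot inherited from an evicted adapter must be zeroed rather than reused as-is. A runnable distillation of that contract (plain PyTorch, lifted from the hunk above, not sglang itself):

```python
from typing import Optional

import torch


def load_lora_weight_tensor(
    buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
) -> None:
    if weight is None:
        # Weight absent from the new adapter: clear residue left by the
        # evicted adapter so it cannot leak into later LoRA matmuls.
        buffer_view.zero_()
    else:
        assert buffer_view.shape == weight.shape
        buffer_view.copy_(weight)


# A buffer slot still holding an evicted adapter's values:
slot = torch.full((8, 16), 7.0)
load_lora_weight_tensor(slot, None)                 # new adapter lacks this module
assert slot.count_nonzero() == 0                    # slot is now inert
load_lora_weight_tensor(slot, torch.randn(8, 16))   # normal copy path
```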
sglang/srt/lora/utils.py
CHANGED
@@ -1,7 +1,7 @@
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import
+from typing import Iterable, Optional, Set, Tuple

 import torch

@@ -106,9 +106,11 @@ def get_hidden_dim(
     raise NotImplementedError()


-def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
+def get_normalized_lora_weight_names(
+    target_modules: Iterable[str],
+) -> Tuple[set[str], set[str]]:
     """
-    Mapping a target module name to names of the normalized LoRA weights.
+    Mapping a list of target module name to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
         "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
         "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
     }
-
-
+
+    result = (set(), set())
+    for name in target_modules:
+        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
+        result[0].update(lora_a)
+        result[1].update(lora_b)
+    return result


 def get_stacked_multiply(module_name: str) -> int:
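`get_normalized_lora_weight_names` previously took a single module name; it now folds a whole `target_modules` iterable into one pair of name sets, which is what `LoRAMemoryPool.can_support` checks against. A standalone sketch of the new behavior, using only the two `params_mapping` entries visible in this hunk (the real table has more entries):

```python
from typing import Iterable, Set, Tuple

# Subset of the mapping shown above; unlisted modules map to themselves.
PARAMS_MAPPING = {
    "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
    "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}


def get_normalized_lora_weight_names(
    target_modules: Iterable[str],
) -> Tuple[Set[str], Set[str]]:
    result = (set(), set())
    for name in target_modules:
        lora_a, lora_b = PARAMS_MAPPING.get(name, ([name], [name]))
        result[0].update(lora_a)  # normalized names for LoRA A
        result[1].update(lora_b)  # normalized names for LoRA B
    return result


a_names, b_names = get_normalized_lora_weight_names(["qkv_proj", "o_proj"])
assert a_names == {"qkv_proj", "o_proj"}
assert b_names == {"q_proj", "kv_proj", "o_proj"}
```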
sglang/srt/managers/cache_controller.py
CHANGED
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)


@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()


+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:

     def __init__(
@@ -169,6 +222,8 @@ class HiCacheController:
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +241,19 @@ class HiCacheController:
         else:
             self.io_backend = io_backend

+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = prefetch_threshold
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +286,26 @@ class HiCacheController:
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
+
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
@@ -232,6 +317,13 @@ class HiCacheController:
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()

         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True
@@ -243,6 +335,16 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -383,3 +485,142 @@ class HiCacheController:
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
+
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # not to prefetch if not enough benefits
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                else:
+                    operation.hash_value = hash_value
+                    logger.debug(
+                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)

+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    # todo, handle failures in storage backend
+                    self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                self.ack_backup_queue.put((operation.id, operation.hash_value))
+
+            except Empty:
+                continue
sglang/srt/managers/io_struct.py
CHANGED
@@ -13,14 +13,14 @@
 # ==============================================================================
 """
 The definition of objects transferred between different
-processes (TokenizerManager, DetokenizerManager,
+processes (TokenizerManager, DetokenizerManager, Scheduler).
 """

 import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional,
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
@@ -42,8 +42,21 @@ class SessionParams:
     drop_previous_output: Optional[bool] = None


-
-
+# Type definitions for multimodal input data
+# Individual data item types for each modality
+ImageDataInputItem = Union[Image, str, Dict]
+AudioDataInputItem = Union[str, Dict]
+VideoDataInputItem = Union[str, Dict]
+# Union type for any multimodal data item
+MultimodalDataInputItem = Union[
+    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
+]
+# Format types supporting single items, lists, or nested lists for batch processing
+MultimodalDataInputFormat = Union[
+    List[List[MultimodalDataInputItem]],
+    List[MultimodalDataInputItem],
+    MultimodalDataInputItem,
+]


 @dataclass
@@ -60,13 +73,11 @@ class GenerateReqInput:
     # - List of images (one per request in a batch)
     # - List of lists of images (multiple images per request)
     # See also python/sglang/srt/utils.py:load_image for more details.
-    image_data: Optional[
-        Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
-    ] = None
-    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
-    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    image_data: Optional[MultimodalDataInputFormat] = None
     # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
-    video_data: Optional[
+    video_data: Optional[MultimodalDataInputFormat] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[MultimodalDataInputFormat] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -297,6 +308,9 @@ class GenerateReqInput:
                     self.modalities.append("image")
                 elif len(self.image_data[i]) > 1:
                     self.modalities.append("multi-images")
+                else:
+                    # Ensure len(self.modalities) == len(self.image_data)
+                    self.modalities.append(None)
             # Expand parallel_sample_num
             self.image_data = self.image_data * self.parallel_sample_num
             self.modalities = self.modalities * self.parallel_sample_num
@@ -521,19 +535,17 @@ class EmbeddingReqInput:
     # - List of images (one per request in a batch)
     # - List of lists of images (multiple images per request)
     # See also python/sglang/srt/utils.py:load_image for more details.
-    image_data: Optional[
-        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
-    ] = None
+    image_data: Optional[MultimodalDataInputFormat] = None
     # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
-    video_data: Optional[
+    video_data: Optional[MultimodalDataInputFormat] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
-    audio_data: Optional[
+    audio_data: Optional[MultimodalDataInputFormat] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
@@ -607,8 +619,6 @@ class EmbeddingReqInput:
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
                 text=[self.text[i]] if self.text is not None else None,
-                input_ids=None,
-                image_data=None,
                 sampling_params=self.sampling_params[i],
                 rid=self.rid[i],
                 is_cross_encoder_request=True,
@@ -618,6 +628,8 @@ class EmbeddingReqInput:
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
             image_data=self.image_data[i] if self.image_data is not None else None,
+            audio_data=self.audio_data[i] if self.audio_data is not None else None,
+            video_data=self.video_data[i] if self.video_data is not None else None,
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
         )
@@ -941,17 +953,6 @@ class ProfileReqType(Enum):
     STOP_PROFILE = 2


-class ExpertDistributionReq(Enum):
-    START_RECORD = 1
-    STOP_RECORD = 2
-    DUMP_RECORD = 3
-
-
-@dataclass
-class ExpertDistributionReqOutput:
-    pass
-
-
 @dataclass
 class ProfileReq:
     type: ProfileReqType
@@ -1001,6 +1002,17 @@ class HealthCheckOutput:
     pass


+class ExpertDistributionReq(Enum):
+    START_RECORD = 1
+    STOP_RECORD = 2
+    DUMP_RECORD = 3
+
+
+@dataclass
+class ExpertDistributionReqOutput:
+    pass
+
+
 @dataclass
 class Function:
     description: Optional[str] = None