sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch

@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
+    get_normalized_lora_weight_names,
     get_stacked_multiply,
     get_weight_name,
 )
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
+        max_lora_rank: int,
+        lora_weight_names: Tuple[Set[str], Set[str]],
+        base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
+        self.max_lora_rank: int = max_lora_rank
+
+        # lora weight names for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
         # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

+        self.init_buffers(base_model)
+
+    def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
+        """
+        Check if the memory pool can support the given LoRA adapters.
+        """
+
+        def _can_support(config: LoRAConfig) -> bool:
+            """
+            Check if the memory pool can support a single LoRA adapter.
+            """
+            if config.r > self.max_lora_rank:
+                return False
+            weights_a, weights_b = get_normalized_lora_weight_names(
+                config.target_modules
+            )
+            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
+                self.lora_weight_names[1]
+            )
+
+        if isinstance(config, LoRAConfig):
+            return _can_support(config)
+        else:
+            return all(_can_support(x) for x in config)
+
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
             max_lora_dim,
         )

-    def init_buffers(
-        self,
-        lora_weight_names: Tuple[Set[str]],
-        base_model: torch.nn.Module,
-        max_lora_dim: int,
-    ):
-        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
+    def init_buffers(self, base_model: torch.nn.Module):
         device = next(base_model.parameters()).device

-        def
+        def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             lora_weight_names: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-
-
-
+            for module_name in lora_weight_names:
+                lora_shape = get_lora_shape_fn(
+                    module_name, base_model, self.max_lora_rank
+                )
                 buffer[module_name] = [
                     torch.empty(
                         lora_shape,
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
                     for _ in range(self.num_layer)
                 ]

-
+        init_buffer(
             self.A_buffer,
-            lora_weight_names[0],
+            self.lora_weight_names[0],
             self.get_lora_A_shape,
         )

-
+        init_buffer(
             self.B_buffer,
-            lora_weight_names[1],
+            self.lora_weight_names[1],
             self.get_lora_B_shape,
         )

@@ -126,7 +153,7 @@ class LoRAMemoryPool:
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
-        lora_modules:
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
@@ -159,12 +186,20 @@ class LoRAMemoryPool:
         uid: str,
         buffer_id: int,
         lora_adapter: LoRAAdapter,
-        lora_modules:
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
-        def
-
-
-
+        def load_lora_weight_tensor(
+            buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
+        ):
+            if weight is None:
+                # If the particular weight is not present in the adapter, we initialize the buffer to zero
+                # to avoid contamination from the residual weight of the evicted adapters.
+                buffer_view.zero_()
+            else:
+                assert (
+                    buffer_view.shape == weight.shape
+                ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+                buffer_view.copy_(weight)

         if uid is None:
             for i in range(self.num_layer):
@@ -176,8 +211,12 @@ class LoRAMemoryPool:
         lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
-            temp_A_buffer: Dict[str, torch.Tensor] = {
-
+            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.A_buffer
+            }
+            temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.B_buffer
+            }
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
@@ -193,6 +232,14 @@ class LoRAMemoryPool:
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+
+                    if temp_A_buffer[weight_name] is None:
+                        # Skip weight slicing if the weight is not present in the adapter
+                        continue
+
                     if "qkv_proj" in module_name:
                         temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                             temp_A_buffer["qkv_proj"], self.tp_rank
@@ -204,9 +251,10 @@ class LoRAMemoryPool:
                             )
                         )
                     else:
-
-
-
+                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
+                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
+                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
+                        # FlashInfer LoRA backend.
                         temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                             temp_A_buffer[weight_name], self.tp_rank
                         )
@@ -219,8 +267,7 @@ class LoRAMemoryPool:
                 buffer_view = self.A_buffer[name][layer_id][buffer_id][
                     : lora_rank * c, :
                 ]
-
-                buffer_view.copy_(weights)
+                load_lora_weight_tensor(buffer_view, weights)

             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
@@ -229,14 +276,15 @@ class LoRAMemoryPool:
                         buffer_view = self.B_buffer[name][layer_id][stacked_id][
                             buffer_id
                         ][:, :lora_rank]
-
-
+                        weight_slice = (
+                            weights[stacked_id] if weights is not None else None
+                        )
+                        load_lora_weight_tensor(buffer_view, weight_slice)
                 else:
                     buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
                         :, :lora_rank
                     ]
-
-                    buffer_view.copy_(weights)
+                    load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
sglang/srt/lora/utils.py
CHANGED
@@ -1,7 +1,7 @@
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import
+from typing import Iterable, Optional, Set, Tuple

 import torch

@@ -106,9 +106,11 @@ def get_hidden_dim(
     raise NotImplementedError()


-def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
+def get_normalized_lora_weight_names(
+    target_modules: Iterable[str],
+) -> Tuple[set[str], set[str]]:
     """
-    Mapping a target module name to names of the normalized LoRA weights.
+    Mapping a list of target module name to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
         "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
         "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
     }
-
-
+
+    result = (set(), set())
+    for name in target_modules:
+        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
+        result[0].update(lora_a)
+        result[1].update(lora_b)
+    return result


 def get_stacked_multiply(module_name: str) -> int:
sglang/srt/managers/cache_controller.py
CHANGED
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)


@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()


+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:

     def __init__(
@@ -166,9 +219,12 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +242,25 @@ class HiCacheController:
         else:
             self.io_backend = io_backend

+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +293,26 @@ class HiCacheController:
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
+
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
@@ -232,6 +324,13 @@ class HiCacheController:
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()

         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True
@@ -243,6 +342,16 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -256,6 +365,7 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
+        torch.cuda.current_stream().synchronize()
         self.write_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -383,3 +493,181 @@ class HiCacheController:
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
+
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # not to prefetch if not enough benefits
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                    logger.debug(
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
+                else:
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
+                    logger.debug(
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+
+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    success = self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )
+
+            except Empty:
+                continue