sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/attention/triton_backend.py +64 -16
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +43 -10
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/models/llama.py +8 -3
- sglang/srt/speculative/build_eagle_tree.py +482 -102
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +134 -61
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -16,307 +16,115 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-import …
+from typing import Dict, List, Set, Tuple

 import torch

-from sglang.srt.…
-from sglang.srt.…
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.layers import get_lora_layer
+from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.mem_pool import LoRAMemoryPool
+from sglang.srt.lora.utils import (
+    LoRABatchInfo,
+    LoRAType,
+    get_customized_names_from_hf_names,
+    get_layer_id,
+    get_stacked_name,
+    get_weight_name,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import …
+from sglang.srt.utils import replace_submodule

 logger = logging.getLogger(__name__)


-def get_module_name(name):
-    # Fallback solution of mapping from config module name to module name in model class.
-    # Please check if it aligns with your base model.
-    # Please implement the function in the model class if it is not.
-    # You can reference this function in llama.py.
-    params_mapping = {
-        "q_proj": "qkv_proj",
-        "k_proj": "qkv_proj",
-        "v_proj": "qkv_proj",
-        "gate_proj": "gate_up_proj",
-        "up_proj": "gate_up_proj",
-    }
-    return params_mapping.get(name, name)
-
-
-def get_hidden_dim(module_name, config):
-    # Fallback solution of get_hidden_dim for different modules
-    # Please check if it aligns with your base model.
-    # Please implement the function in the model class if it is not.
-    # You can reference this function in llama.py.
-    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-        return config.hidden_size, config.hidden_size
-    elif module_name in ["kv_proj"]:
-        return config.hidden_size, config.hidden_size // (
-            config.num_attention_heads // config.num_key_value_heads
-        )
-    elif module_name == "gate_up_proj":
-        return config.hidden_size, config.intermediate_size
-    elif module_name == "down_proj":
-        return config.intermediate_size, config.hidden_size
-    else:
-        raise NotImplementedError()
-
-
-def get_stacked_name(name):
-    # origin name -> (name for A, name for B)
-    params_mapping = {
-        "q_proj": ("qkv_proj", "q_proj"),
-        "k_proj": ("qkv_proj", "kv_proj"),
-        "v_proj": ("qkv_proj", "kv_proj"),
-        "gate_proj": ("gate_up_proj", "gate_up_proj"),
-        "up_proj": ("gate_up_proj", "gate_up_proj"),
-    }
-    return params_mapping.get(name, (name, name))
-
-
-def get_backend_from_name(name):
-    backend_mapping = {
-        "triton": TritonLoraBackend,
-        "flashinfer": FlashInferLoraBackend,
-    }
-
-    if name in backend_mapping:
-        return backend_mapping[name]
-
-    raise Exception(
-        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
-    )
-
-
-def get_layer_id(name):
-    match = re.search(r"layers\.(\d+)\.", name)
-    if match is None:
-        return None
-    return int(match.group(1))
-
-
 class LoRAManager:
     def __init__(
         self,
-        base_model,
-        lora_paths,
-        base_hf_config,
-        max_loras_per_batch,
-        load_config,
-        dtype,
-        lora_backend,
+        base_model: torch.nn.Module,
+        lora_paths: Dict[str, str],
+        base_hf_config: AutoConfig,
+        max_loras_per_batch: int,
+        load_config: LoadConfig,
+        dtype: torch.dtype,
+        lora_backend: str = "triton",
     ):
-        self.base_model = base_model
-        self.lora_paths = lora_paths
-        self.base_hf_config = base_hf_config
-        self.max_loras_per_batch = max_loras_per_batch
-        self.load_config = load_config
-        self.dtype = dtype
-
-…
+        self.base_model: torch.nn.Module = base_model
+        self.lora_paths: Dict[str, str] = lora_paths
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.max_loras_per_batch: int = max_loras_per_batch
+        self.load_config: LoadConfig = load_config
+        self.dtype: torch.dtype = dtype
+
+        # LoRA backend for running sgemm kernels
+        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

         self.init_loras()
         self.init_lora_memory_pool()
-        self.init_lora_batch()
-
-    def match_target_modules(self, module_name):
-        for target_module in self.target_modules:
-            if module_name.split(".")[-1] == target_module:
-                return True
-        return False
-
-    def get_target_modules(self):
-        modules = []
-        for module_name, module in self.base_model.named_modules():
-            if self.match_target_modules(module_name):
-                modules.append((module_name, module))
-        return modules
-
-    def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
-        replace_submodule(self.base_model, module_name, lora_module)
-        return lora_module

     def init_loras(self):
-        # …
-        self.configs = {}
-…
+        # Config of each LoRA adapter
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
-            self.…
+            self.hf_target_names = set(self.hf_target_names) | set(
                 self.configs[name].target_modules
             )
-…
-…
-…
-…
-…
-        else:
-            logger.warning(
-                "WARNING: get_module_name() is not defined, "
-                "which is used to map config module name to model implementation module name."
-                "Use the default one, but please check if it is correct for your model."
-            )
-            self.target_modules = {
-                get_module_name(module) for module in self.origin_target_modules
-            }
-        self.target_weights = set(
-            [get_stacked_name(module) for module in self.origin_target_modules]
+
+        # Target lora weight names for lora_a and lora_b modules repectively.
+        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
+        self.lora_weight_names: Set[Tuple[str]] = set(
+            [get_stacked_name(module) for module in self.hf_target_names]
         )

         # load all weights to cpu
-        self.loras = …
-        self.lora_id = {}
+        self.loras: Dict[str, LoRAAdapter] = {}
         for name in self.lora_paths.keys():
-…
-…
-…
-…
-…
-…
-                self.load_config,
-                self.lora_backend,
-            )
+            lora_adapter = LoRAAdapter(
+                name,
+                self.configs[name],
+                self.base_hf_config,
+                self.load_config,
+                self.lora_backend,
             )
-…
+            lora_adapter.initialize_weights()
+            self.loras[name] = lora_adapter

         # misc lora configs
-…
-        self.…
-…
+        # FIXME remove the restrictions after implementing unified paging
+        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+        self.scaling: float = list(self.loras.values())[0].scaling
         assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
-        assert all(x.scaling == self.scaling for x in self.loras)
+        assert all(x.scaling == self.scaling for x in self.loras.values())

-        # …
-        self.…
-        for module_name, module in self.get_target_modules():
-            self.lora_modules.append(
-                (module_name, self.set_lora_module(module_name, module))
-            )
+        # Convert original model layers to layers with LoRA
+        self.convert_to_lora_layers()

     def init_lora_memory_pool(self):
-        # …
-        self.…
-…
-…
-        for module_A, module_B in self.target_weights:
-            # init A tensor, column_major=True
-            if hasattr(self.base_model, "get_hidden_dim"):
-                hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
-            else:
-                logger.warning(
-                    "WARNING: get_hidden_dim() is not defined, "
-                    "which is used to get the hidden dim for different lora modules"
-                    "Use the default one, but please check if it is correct for your model."
-                )
-                hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
-            c = self.loras[-1].get_stacked_multiply(module_A)
-            if module_A not in self.A_buffer:
-                self.A_buffer[module_A] = [
-                    torch.empty(
-                        (
-                            self.max_loras_per_batch,
-                            self.max_lora_dim * c,
-                            hidden_dim_A,
-                        ),
-                        dtype=self.dtype,
-                        device="cuda",
-                    )
-                    for i in range(num_layer)
-                ]
-            # init B tensor, column_major=True
-            if hasattr(self.base_model, "get_hidden_dim"):
-                _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
-            else:
-                logger.warning(
-                    "WARNING: get_hidden_dim() is not defined, "
-                    "which is used to get the hidden dim for different lora modules"
-                    "Use the default one, but please check if it is correct for your model."
-                )
-                _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
-            c = self.loras[-1].get_stacked_multiply(module_B)
-            if module_B not in self.B_buffer:
-                self.B_buffer[module_B] = [
-                    torch.empty(
-                        (
-                            c,
-                            self.max_loras_per_batch,
-                            hidden_dim_B,
-                            self.max_lora_dim,
-                        ),
-                        dtype=self.dtype,
-                        device="cuda",
-                    )
-                    for i in range(num_layer)
-                ]
-
-    def init_lora_batch(self):
-        self.active_uids = set()  # set of active loras
-        self.buffer_id = {}  # lora uid -> idx in memory pool
-
-    def get_weight_name(self, name, idx):
-        for target_weight_name in self.target_weights:
-            if target_weight_name[idx] in name:
-                return target_weight_name[idx]
-
-    def load_lora(self, uid, buffer_id):
-        num_layer = self.base_hf_config.num_hidden_layers
-        if uid is None:
-            for i in range(num_layer):
-                for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
-            return
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
+        )

-…
-…
-            for name, weights in layer_weights.items():
-                if "lora_A" in name:
-                    lora_weight_name = self.get_weight_name(name, 0)
-                    if lora_weight_name:
-                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
-                else:
-                    lora_weight_name = self.get_weight_name(name, 1)
-                    if lora_weight_name:
-                        c = self.loras[-1].get_stacked_multiply(lora_weight_name)
-                        if c > 1:
-                            for j in range(c):
-                                self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
-                                    weights[j]
-                                )
-                        else:
-                            self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
-                                weights
-                            )
+        # Initialize target lora modules in memory pool
+        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-
-        j = len(self.active_uids)
-        evictable_uids = list(self.active_uids)
-        for uid in cur_uids:
-            if uid not in self.active_uids:
-                if j < self.max_loras_per_batch:
-                    index = j
-                    j += 1
-                else:
-                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                        i += 1
-                    assert i < len(evictable_uids)
-                    self.active_uids.remove(evictable_uids[i])
-                    self.buffer_id.pop(evictable_uids[i])
-                    index = i
-                    i += 1
-                self.load_lora(uid, index)
-                self.active_uids.add(uid)
-                self.buffer_id[uid] = index
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

+        # FIXME: Handle lora uid with None more safely
         if cur_uids == set([None]):
             return

@@ -332,9 +140,9 @@ class LoRAManager:
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.…
+            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)

-        batch_info = …
+        batch_info = LoRABatchInfo(
             bs=bs,
             seg_lens=seg_lens,
             seg_indptr=seg_indptr,
@@ -346,16 +154,40 @@ class LoRAManager:
         # call set_lora_info for each lora modules
         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)
-
             if "qkv_proj" not in module_name:
-                weight_name = …
+                weight_name = get_weight_name(
+                    module_name, self.lora_weight_names, LoRAType.LORA_A
+                )
                 module.set_lora_info(
-                    self.…
-                    self.…
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
                 )
             else:
                 module.set_lora_info(
-                    self.…
-                    self.…
-                    self.…
+                    self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
+                )
+
+    def set_lora_module(self, module_name, module):
+        lora_module = get_lora_layer(
+            module, self.max_lora_dim, self.scaling, self.lora_backend
+        )
+        replace_submodule(self.base_model, module_name, lora_module)
+        return lora_module
+
+    def convert_to_lora_layers(self):
+        # Target module names of customized layers defined in python/sglang/srt/layers
+        # e.g., {"qkv_proj", "o_proj"}
+        customized_target_names = get_customized_names_from_hf_names(
+            self.hf_target_names, self.base_model
+        )
+
+        # Monkey patch to use the LoRA version layers
+        self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
+        for module_name, module in self.base_model.named_modules():
+            # The module should be converted if it is included in target_names
+            if module_name.split(".")[-1] in customized_target_names:
+                self.lora_modules.append(
+                    (module_name, self.set_lora_module(module_name, module))
                 )
sglang/srt/lora/mem_pool.py
ADDED
@@ -0,0 +1,174 @@
+from typing import Dict, List, Optional, Set, Tuple
+
+import torch
+
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.lora import LoRAAdapter
+from sglang.srt.lora.utils import (
+    LoRAType,
+    get_hidden_dim,
+    get_stacked_multiply,
+    get_weight_name,
+)
+
+
+class LoRAMemoryPool:
+    """Class for memory pool management of lora modules"""
+
+    def __init__(
+        self,
+        base_hf_config: AutoConfig,
+        max_loras_per_batch: int,
+        max_lora_dim: int,
+        dtype: torch.dtype,
+    ):
+
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.num_layer: int = base_hf_config.num_hidden_layers
+        self.max_loras_per_batch: int = max_loras_per_batch
+        self.max_lora_dim: int = max_lora_dim
+        self.dtype: torch.dtype = dtype
+
+        # Both A_buffer and B_buffer maps lora weight names to its buffer space.
+        # A_buffer contains num_layer number of row-major tensors with shape
+        # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
+        # B_buffer contains num_layer number of column-major tensors with shape
+        # (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
+        self.A_buffer: Dict[str, List[torch.Tensor]] = {}
+        self.B_buffer: Dict[str, List[torch.Tensor]] = {}
+
+        # Lora uid -> buffer idx in memory pool
+        self.uid_to_buffer_id: Dict[Optional[str], int] = {}
+
+        # Buffer idx -> lora uid in memory pool
+        # All uids are initalized as empty strings for empty buffer slots
+        # Here we don't initalize to None since None is a valid uid
+        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
+
+    def init_buffers(
+        self,
+        lora_weight_names: Set[Tuple[str]],
+        base_model: torch.nn.Module,
+    ):
+
+        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
+        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
+        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
+
+        for module_A, module_B in lora_weight_names:
+            # Init A tensor, column_major=False
+            input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
+            c = get_stacked_multiply(module_A)
+            if module_A not in self.A_buffer:
+                self.A_buffer[module_A] = [
+                    torch.empty(
+                        (
+                            self.max_loras_per_batch,
+                            self.max_lora_dim * c,
+                            input_dim,
+                        ),
+                        dtype=self.dtype,
+                        device="cuda",
+                    )
+                    for i in range(self.num_layer)
+                ]
+
+            # Init B tensor, column_major=True
+            _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
+            c = get_stacked_multiply(module_B)
+            if module_B not in self.B_buffer:
+                self.B_buffer[module_B] = [
+                    torch.empty(
+                        (
+                            c,  # stacked lora_b modules might need separation
+                            self.max_loras_per_batch,
+                            output_dim,
+                            self.max_lora_dim,
+                        ),
+                        dtype=self.dtype,
+                        device="cuda",
+                    )
+                    for i in range(self.num_layer)
+                ]
+
+    def prepare_lora_batch(
+        self,
+        cur_uids: Set[Optional[str]],
+        lora_adapters: Dict[str, LoRAAdapter],
+    ):
+
+        def get_available_buffer_slot():
+            for buffer_id in range(self.max_loras_per_batch):
+                # Prioritize empty slots
+                if self.buffer_id_to_uid[buffer_id] == "":
+                    return buffer_id, ""
+
+            for buffer_id in range(self.max_loras_per_batch):
+                # Evict unneeded lora
+                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
+                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+
+            raise ValueError(
+                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
+            )
+
+        for uid in cur_uids:
+            if uid not in self.uid_to_buffer_id:
+                buffer_id, evicted_lora_uid = get_available_buffer_slot()
+                if evicted_lora_uid != "":
+                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                self.load_lora_weight_to_buffer(
+                    uid, buffer_id, lora_adapters.get(uid, None)
+                )
+                self.uid_to_buffer_id[uid] = buffer_id
+                self.buffer_id_to_uid[buffer_id] = uid
+
+    def load_lora_weight_to_buffer(
+        self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
+    ):
+
+        if uid is None:
+            for i in range(self.num_layer):
+                for k in self.A_buffer.keys():
+                    self.A_buffer[k][i][buffer_id] *= 0
+            return
+
+        assert lora_adapter is not None
+        for layer_id in range(self.num_layer):
+            layer_weights = lora_adapter.layers[layer_id].weights
+            for name, weights in layer_weights.items():
+                if "lora_A" in name:
+                    lora_weight_name = get_weight_name(
+                        name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+                    if lora_weight_name:
+                        self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
+                            weights
+                        )
+                else:
+                    lora_weight_name = get_weight_name(
+                        name, self.lora_weight_names, LoRAType.LORA_B
+                    )
+                    if lora_weight_name:
+                        c = get_stacked_multiply(lora_weight_name)
+                        if c > 1:
+                            for stacked_id in range(c):
+                                self.B_buffer[lora_weight_name][layer_id][stacked_id][
+                                    buffer_id
+                                ].copy_(weights[stacked_id])
+                        else:
+                            self.B_buffer[lora_weight_name][layer_id][0][
+                                buffer_id
+                            ].copy_(weights)
+
+    def get_tensor(
+        self, weight_name: str, layer_id: int, lora_type: LoRAType
+    ) -> torch.Tensor:
+
+        if lora_type == LoRAType.LORA_A:
+            return self.A_buffer[weight_name][layer_id]
+
+        return self.B_buffer[weight_name][layer_id]
+
+    def get_buffer_id(self, lora_uid: str):
+        return self.uid_to_buffer_id[lora_uid]
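
The pool's `prepare_lora_batch` follows a simple slot policy: a uid already resident keeps its slot, otherwise an empty slot is preferred, otherwise a resident adapter not needed by the current batch is evicted. Below is a stripped-down, self-contained sketch of just that bookkeeping (plain Python, no tensors); `MiniPool` is illustrative and not part of sglang.

```python
from typing import Dict, List, Optional, Set


class MiniPool:
    """Toy model of LoRAMemoryPool's slot bookkeeping (illustrative only)."""

    def __init__(self, max_loras_per_batch: int):
        # "" marks an empty slot; None is a legal uid (the no-LoRA case).
        self.buffer_id_to_uid: List[Optional[str]] = [""] * max_loras_per_batch
        self.uid_to_buffer_id: Dict[Optional[str], int] = {}

    def _available_slot(self, cur_uids: Set[Optional[str]]):
        # Prefer empty slots, then evict any resident uid not in the current batch.
        for i, uid in enumerate(self.buffer_id_to_uid):
            if uid == "":
                return i, ""
        for i, uid in enumerate(self.buffer_id_to_uid):
            if uid not in cur_uids:
                return i, uid
        raise ValueError("more active loras than max_loras_per_batch")

    def prepare_batch(self, cur_uids: Set[Optional[str]]):
        for uid in cur_uids:
            if uid not in self.uid_to_buffer_id:
                slot, evicted = self._available_slot(cur_uids)
                if evicted != "":
                    self.uid_to_buffer_id.pop(evicted)
                # The real pool copies the adapter's A/B weights into slot `slot` here.
                self.uid_to_buffer_id[uid] = slot
                self.buffer_id_to_uid[slot] = uid


pool = MiniPool(max_loras_per_batch=2)
pool.prepare_batch({"adapter-a", "adapter-b"})
pool.prepare_batch({"adapter-c"})  # evicts one of a/b
print(pool.buffer_id_to_uid)       # e.g. ['adapter-c', 'adapter-b']
```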
sglang/srt/lora/triton_ops/__init__.py
CHANGED
@@ -1,5 +1,11 @@
+from .gate_up_lora_b import gate_up_lora_b_fwd
 from .qkv_lora_b import qkv_lora_b_fwd
 from .sgemm_lora_a import sgemm_lora_a_fwd
 from .sgemm_lora_b import sgemm_lora_b_fwd

-__all__ = […
+__all__ = [
+    "gate_up_lora_b_fwd",
+    "qkv_lora_b_fwd",
+    "sgemm_lora_a_fwd",
+    "sgemm_lora_b_fwd",
+]