sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +6 -25
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +104 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +58 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +117 -131
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +57 -44
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -1
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora.py
ADDED
@@ -0,0 +1,403 @@
```python
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py


import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import safetensors.torch
import torch
from torch import nn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader

from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


class BaseLayerWithLoRA(nn.Module):
    def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
        super().__init__()
        self.base_layer = base_layer
        self.segment_gemm = segment_gemm
        self.lora_rank = lora_rank
        self.scaling = scaling
        self.set_lora = False

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)

    def set_lora_info(self, *args):
        pass


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # TODO
        return output

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ColumnParallelLinear
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_)

        if self.base_layer.gather_output:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # FIXME
        assert lora_a_output.shape[-1] == self.lora_rank * 2
        lora_output = torch.empty_like(base_output)
        output_dim = lora_output.shape[-1] // 2
        for i in range(2):
            left = output_dim * i
            right = left + output_dim
            lora_output[:, left:right] = self.segment_gemm.run(
                x=lora_a_output[
                    :, self.lora_rank * i : self.lora_rank * (i + 1)
                ].contiguous(),
                weights=self.B_buffer[:, left:right, :].contiguous(),
                batch_size=self.bs,
                weight_column_major=True,
                seg_lens=self.seq_lens,
                weight_indices=self.weight_indices,
            )
        return base_output + lora_output * self.scaling


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(
        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
        self.B_buffer_q = B_buffer_q
        self.B_buffer_kv = B_buffer_kv
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer_qkv,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # FIXME parallelize qkv
        lora_output = torch.empty_like(base_output)
        # q
        output_dim_q = self.B_buffer_q.shape[-2]
        lora_output[:, :output_dim_q] = self.segment_gemm.run(
            x=lora_a_output[:, : self.lora_rank].contiguous(),
            weights=self.B_buffer_q,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # kv
        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
        for i in range(2):
            left = output_dim_kv * i
            right = left + output_dim_kv
            lora_output[:, output_dim_q + left : output_dim_q + right] = (
                self.segment_gemm.run(
                    x=lora_a_output[
                        :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
                    ].contiguous(),
                    weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                    batch_size=self.bs,
                    weight_column_major=True,
                    seg_lens=self.seq_lens,
                    weight_indices=self.weight_indices,
                )
            )
        return base_output + lora_output * self.scaling


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        lora_output = self.segment_gemm.run(
            x=lora_output,
            weights=self.B_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling

    def forward(self, input_):
        # duplicate the logic in RowParallelLinear
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_parallel
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (
                output_ + self.base_layer.bias
                if self.base_layer.bias is not None
                else output_
            )
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias


def get_lora_layer(
    layer: nn.Module, segment_gemm, lora_rank, scaling
) -> BaseLayerWithLoRA:
    supported_layer_types = {
        # the order matters
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLoRA,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
            return ret
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")


def get_mapped_params(module_names):
    ret = set()
    for module_name in module_names:
        ret.add(params_mapping(module_name))
    return list(ret)


class LoRALayer(nn.Module):
    def __init__(self, config, base_hf_config):
        super().__init__()
        self.config = config
        self.base_hf_config = base_hf_config
        self.weights = {}
        self.weight_gpu = {}

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = None


class LoRAAdapter(nn.Module):
    def __init__(self, uid, config, base_hf_config, load_config):
        super().__init__()
        self.uid = uid
        self.config = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config = base_hf_config
        self.load_config = load_config
        self.scaling = self.config.lora_alpha / self.config.r

        self.layers = nn.ModuleList(
            [
                LoRALayer(config, base_hf_config)
                for i in range(base_hf_config.num_hidden_layers)
            ]
        )

        self.weights = {}
        self.weights_gpu = {}

    def get_stacked_multiply(self, module_name):
        stacked_rank = {
            "qkv_proj": 3,
            "kv_proj": 2,
            "gate_up_proj": 2,
        }
        return stacked_rank[module_name] if module_name in stacked_rank else 1

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
        for layer in self.layers:
            layer.load_to_gpu()

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = None
        for layer in self.layers:
            layer.offload_from_gpu()

    # initialize the LoRA weights to cpu
    def initialize_weights(self):
        model_path = self.config.path
        loader = DefaultModelLoader(self.load_config)
        revision = getattr(self.config.hf_config, "revision", None)
        for name, loaded_weight in loader._get_weights_iterator(
            model_path, revision=revision, fall_back_to_pt=True
        ):
            match = re.search(r"layers\.(\d+)\.", name)
            if match is not None:
                layer_id = int(match.group(1))
                self.layers[layer_id].weights[name] = loaded_weight.cpu()
            else:
                self.weights[name] = loaded_weight.cpu()

        # stack kv_proj and gate_up_proj
        for i in range(self.base_hf_config.num_hidden_layers):
            layer = self.layers[i]
            weight_names = [name for name, _ in layer.weights.items()]
            for weight_name in weight_names:
                if "k_proj" in weight_name:
                    q_name = weight_name.replace("k_proj", "q_proj")
                    v_name = weight_name.replace("k_proj", "v_proj")
                    kv_name = weight_name.replace("k_proj", "kv_proj")
                    qkv_name = weight_name.replace("k_proj", "qkv_proj")
                    if "lora_A" in weight_name:
                        layer.weights[qkv_name] = torch.cat(
                            (
                                layer.weights[q_name],
                                layer.weights[weight_name],
                                layer.weights[v_name],
                            ),
                            0,
                        )
                        layer.weights.pop(q_name)
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                    else:
                        layer.weights[kv_name] = torch.cat(
                            (
                                layer.weights[weight_name],
                                layer.weights[v_name],
                            ),
                            0,
                        )
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                elif "gate_proj" in weight_name:
                    up_name = weight_name.replace("gate_proj", "up_proj")
                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                    layer.weights[gate_up_name] = torch.cat(
                        (layer.weights[weight_name], layer.weights[up_name]), 0
                    )
                    layer.weights.pop(weight_name)
                    layer.weights.pop(up_name)
```
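The `apply_lora` methods above hand the batched LoRA matmuls to flashinfer's segment GEMM. As a rough reference for what one such call computes, here is a minimal pure-PyTorch sketch (the helper `segmented_lora_delta` is hypothetical and not part of sglang), assuming `seg_lens` gives each request's token count in the packed batch and `weight_indices` selects the adapter slot used by each request:

```python
import torch


def segmented_lora_delta(x, lora_A, lora_B, seg_lens, weight_indices, scaling):
    """Reference sketch (no flashinfer) of a segmented LoRA update.

    x:              (total_tokens, hidden_dim) packed activations
    lora_A:         (num_slots, rank, hidden_dim) A weights, one slot per adapter
    lora_B:         (num_slots, out_dim, rank) B weights
    seg_lens:       per-request token counts; requests are packed back to back in x
    weight_indices: adapter slot used by each request
    """
    delta = torch.zeros(x.shape[0], lora_B.shape[1], dtype=x.dtype)
    start = 0
    for length, slot in zip(seg_lens, weight_indices):
        end = start + length
        a, b = lora_A[slot], lora_B[slot]
        # y = (x A^T) B^T, scaled by lora_alpha / r
        delta[start:end] = (x[start:end] @ a.T) @ b.T * scaling
        start = end
    return delta


# toy usage: 2 requests (3 and 2 tokens) using adapter slots 0 and 1
x = torch.randn(5, 16)
A = torch.randn(2, 4, 16)  # rank 4
B = torch.randn(2, 16, 4)
print(segmented_lora_delta(x, A, B, seg_lens=[3, 2], weight_indices=[0, 1], scaling=0.5).shape)
```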
sglang/srt/lora/lora_config.py
ADDED
@@ -0,0 +1,43 @@
```python
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import os

from huggingface_hub import snapshot_download


class LoRAConfig:
    def __init__(
        self,
        path: str,
    ) -> None:
        self.path = path
        self.hf_config = self.get_lora_config()
        self.target_modules = self.hf_config["target_modules"]
        self.r = self.hf_config["r"]
        self.lora_alpha = self.hf_config["lora_alpha"]

    def get_lora_config(self, dummy=False):
        if dummy:
            raise NotImplementedError()
        else:
            if not os.path.isdir(self.path):
                weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])
            else:
                weights_dir = self.path
            config_name = "adapter_config.json"
            with open(os.path.join(weights_dir, config_name), "r") as f:
                return json.load(f)
```
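As a quick illustration (the adapter id below is hypothetical), `LoRAConfig` simply reads `adapter_config.json` from a local directory or a Hugging Face repo and exposes the fields the manager needs:

```python
from sglang.srt.lora.lora_config import LoRAConfig

# hypothetical adapter id; any PEFT-style LoRA adapter with an adapter_config.json should work
config = LoRAConfig("some-user/some-llama-lora")
print(config.r, config.lora_alpha, config.target_modules)
```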
sglang/srt/lora/lora_manager.py
ADDED
@@ -0,0 +1,256 @@
```python
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"


import re
from dataclasses import dataclass

import torch
from flashinfer import SegmentGEMMWrapper

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import replace_submodule


def get_stacked_name(name):
    # origin name -> (name for A, name for B)
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))


def get_layer_id(name):
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))


class LoRAManager:
    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()

    def match_target_modules(self, module_name):
        for target_module in self.target_modules:
            if module_name.split(".")[-1] == target_module:
                return True
        return False

    def get_target_modules(self):
        modules = []
        for module_name, module in self.base_model.named_modules():
            if self.match_target_modules(module_name):
                modules.append((module_name, module))
        return modules

    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        # get configs and target modules
        self.configs = {}
        self.origin_target_modules = set()
        for path in self.lora_paths:
            self.configs[path] = LoRAConfig(path)
            self.origin_target_modules = set(self.origin_target_modules) | set(
                self.configs[path].target_modules
            )
        self.target_modules = set(
            [
                self.base_model.get_module_name(module)
                for module in self.origin_target_modules
            ]
        )
        self.target_weights = set(
            [get_stacked_name(module) for module in self.origin_target_modules]
        )

        # load all weights to cpu
        self.loras = []
        self.lora_id = {}
        for path in self.lora_paths:
            self.lora_id[path] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    path, self.configs[path], self.base_hf_config, self.load_config
                )
            )
            self.loras[-1].initialize_weights()

        # misc lora configs
        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
        self.scaling = self.loras[0].scaling
        # FIXME remove the restrictions
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # monkey patch to use the LoRA version
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        # preallocate lora memory pool
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # init A tensor, column_major=True
            hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            c = self.loras[-1].get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]
            # init B tensor, column_major=True
            _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            c = self.loras[-1].get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]

    def init_lora_batch(self):
        self.active_uids = set()  # set of active loras
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]

    def load_lora(self, uid, buffer_id):
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
            for i in range(num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        for i in range(num_layer):
            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def prepare_lora_batch(self, batch, extend_seq_lens=None):
        # load active loras into lora memory pool
        cur_uids = set([req.lora_path for req in batch.reqs])
        assert len(cur_uids) <= self.max_loras_per_batch
        i = 0
        evictable_uids = list(self.active_uids)
        for uid in cur_uids:
            if uid not in self.active_uids:
                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                    i += 1
                if i < len(evictable_uids):
                    self.active_uids.remove(evictable_uids[i])
                    self.buffer_id.pop(evictable_uids[i])
                self.load_lora(uid, i)
                self.active_uids.add(uid)
                self.buffer_id[uid] = i
                i += 1

        if cur_uids == set([None]):
            return

        # setup lora in forward modules
        bs = len(batch.reqs)
        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, req in enumerate(batch.reqs):
            weight_indices[i] = self.buffer_id[req.lora_path]

        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
```
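The slot management in `prepare_lora_batch` can be read in isolation: a batch may reference at most `max_loras_per_batch` distinct adapters, adapters already resident keep their slot, and missing ones take slots whose current occupant is not needed by the incoming batch. The standalone sketch below (hypothetical helper `assign_slots`, toy dictionaries, no GPU weight copies) illustrates that policy under those assumptions:

```python
def assign_slots(cur_uids, active_uids, buffer_id, max_slots):
    """Toy model of the slot assignment in LoRAManager.prepare_lora_batch."""
    assert len(cur_uids) <= max_slots
    i = 0
    evictable = list(active_uids)
    for uid in cur_uids:
        if uid not in active_uids:
            # skip resident adapters that the current batch still needs
            while i < len(evictable) and evictable[i] in cur_uids:
                i += 1
            if i < len(evictable):
                active_uids.remove(evictable[i])
                buffer_id.pop(evictable[i])
            # in the real manager, load_lora(uid, i) copies the adapter into slot i here
            active_uids.add(uid)
            buffer_id[uid] = i
            i += 1
    return buffer_id


active, slots = set(), {}
assign_slots({"adapter_a"}, active, slots, max_slots=1)
assign_slots({"adapter_b"}, active, slots, max_slots=1)
print(slots)  # adapter_a was evicted; adapter_b now occupies slot 0
```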