sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/attention/triton_backend.py +64 -16
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +43 -10
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/models/llama.py +8 -3
- sglang/srt/speculative/build_eagle_tree.py +482 -102
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +134 -61
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora.py
CHANGED
@@ -19,282 +19,25 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
 
 import re
-from dataclasses import dataclass
+from typing import Dict, List
 
 import torch
 from torch import nn
 
-from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
-@dataclass
-class LoraBatchInfo:
-    # Batch size
-    bs: int
-
-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
-
-    # Indice pointers of each sequence in shape (bs + 1, )
-    seg_indptr: torch.Tensor
-
-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
-    weight_indices: torch.Tensor
-
-
-class BaseLayerWithLoRA(nn.Module):
-    def __init__(self, base_layer, lora_rank, scaling, lora_backend):
-        super().__init__()
-        self.base_layer = base_layer
-        self.lora_rank = lora_rank
-        self.scaling = scaling
-        self.set_lora = False
-        self.lora_backend = lora_backend
-
-    def forward(self, x: torch.Tensor):
-        return self.base_layer.forward(x)
-
-    def set_lora_info(self, *args):
-        pass
-
-
-class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-        self.weight = base_layer.weight
-
-
-class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        # TODO
-        return output
-
-    def forward(self, input_: torch.Tensor):
-        # duplicate the logic in ColumnParallelLinear
-        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_, bias
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_)
-
-        if self.base_layer.gather_output:
-            output = tensor_model_parallel_all_gather(output_parallel)
-        else:
-            output = output_parallel
-        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
-        return output, output_bias
-
-
-class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def __init__(
-        self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer,
-        B_buffer,
-    ):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
-
-        output_dim = base_output.shape[-1]
-        lora_output = torch.empty_like(base_output)
-        lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
-            x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
-            weights=self.B_buffer[0],
-        )
-
-        lora_output[:, output_dim : 2 * output_dim] = (
-            self.lora_backend.run_lora_b_sgemm(
-                x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
-                weights=self.B_buffer[1],
-            )
-        )
-
-        return base_output + lora_output * self.scaling
-
-
-class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
-        self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer_qkv,
-        B_buffer_q,
-        B_buffer_kv,
-    ):
-        self.set_lora = True
-        self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_qkv_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            self.B_buffer_qkv = torch.cat(
-                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-            ).contiguous()
-
-            # Offsets of q/k/v in output dimension
-            self.output_offset = torch.tensor(
-                [
-                    0,
-                    output_dim_q,
-                    output_dim_q + output_dim_kv,
-                    output_dim_q + 2 * output_dim_kv,
-                ],
-                dtype=torch.int32,
-                device=B_buffer_q.device,
-            )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
-            self.output_offset = None
-            self.max_qkv_out_dim = None
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-            output_offset=self.output_offset,
-            max_qkv_out_dim=self.max_qkv_out_dim,
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-
-class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(self, A_buffer, B_buffer):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
-        lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-    def forward(self, input_):
-        # duplicate the logic in RowParallelLinear
-        if self.base_layer.input_is_parallel:
-            input_parallel = input_
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.base_layer.tp_size
-            )
-            input_parallel = splitted_input[tp_rank].contiguous()
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_parallel
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_parallel)
-
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
-            output_ = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output_ = output_parallel
-
-        if not self.base_layer.skip_bias_add:
-            output = (
-                output_ + self.base_layer.bias
-                if self.base_layer.bias is not None
-                else output_
-            )
-            output_bias = None
-        else:
-            output = output_
-            output_bias = self.base_layer.bias
-        return output, output_bias
-
-
-def get_lora_layer(
-    layer: nn.Module, lora_rank, scaling, lora_backend
-) -> BaseLayerWithLoRA:
-    supported_layer_types = {
-        # the order matters
-        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
-        QKVParallelLinear: QKVParallelLinearWithLoRA,
-        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
-        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
-        RowParallelLinear: RowParallelLinearWithLoRA,
-    }
-    for src_layer_type, lora_layer_type in supported_layer_types.items():
-        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
-            return ret
-    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
-
-
-def get_mapped_params(module_names):
-    ret = set()
-    for module_name in module_names:
-        ret.add(params_mapping(module_name))
-    return list(ret)
-
-
 class LoRALayer(nn.Module):
-    def __init__(self, config, base_hf_config):
+    def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
         super().__init__()
-        self.config = config
-        self.base_hf_config = base_hf_config
-        self.weights = {}
-        self.weight_gpu = {}
+        self.config: LoRAConfig = config
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weight_gpu: Dict[str, torch.Tensor] = {}
 
     def load_to_gpu(self):
        for name, weight in self.weights.items():
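The LoraBatchInfo dataclass and the BaseLayerWithLoRA wrapper hierarchy removed in the hunk above do not disappear from the package: judging by the file summary, they appear to have moved into the new sglang/srt/lora/utils.py and sglang/srt/lora/layers.py modules. For orientation, the computation these wrappers delegate to the backend (run_lora_a_sgemm followed by run_lora_b_sgemm, optionally fused with the scaling-and-add) is the plain low-rank update. Below is a minimal, unfused sketch in PyTorch with made-up names and shapes; apply_lora_delta is not an sglang function, just an illustration of the math.

import torch


def apply_lora_delta(
    base_output: torch.Tensor,  # (num_tokens, output_dim), result of the base linear
    x: torch.Tensor,            # (num_tokens, input_dim), layer input
    lora_a: torch.Tensor,       # (input_dim, r)
    lora_b: torch.Tensor,       # (r, output_dim)
    scaling: float,
) -> torch.Tensor:
    # Unfused equivalent of the A-sgemm then B-sgemm path, with the
    # scaling-and-add done here instead of inside the kernel.
    lora_a_output = x @ lora_a           # down-project into the rank-r space
    lora_delta = lora_a_output @ lora_b  # up-project back to output_dim
    return base_output + lora_delta * scaling


# Toy shapes: 2 tokens, hidden size 8, rank 4.
x = torch.randn(2, 8)
base = torch.randn(2, 8)
out = apply_lora_delta(base, x, torch.randn(8, 4), torch.randn(4, 8), scaling=2.0)
assert out.shape == (2, 8)

The real backends batch this over per-sequence adapter indices (LoraBatchInfo) and, when fuse_output_scaling_add is set, fold the base_output + delta * scaling step into the kernel, which is why the removed apply_lora methods return lora_output directly in that case.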
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):
 
 
 class LoRAAdapter(nn.Module):
-    def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
+    def __init__(
+        self,
+        uid: str,
+        config: LoRAConfig,
+        base_hf_config: AutoConfig,
+        load_config: LoadConfig,
+        lora_backend: BaseLoRABackend,
+    ):
         super().__init__()
-        self.uid = uid
-        self.config = config
+        self.uid: str = uid
+        self.config: LoRAConfig = config
         assert self.config.hf_config["peft_type"].lower() == "lora"
-        self.base_hf_config = base_hf_config
-        self.load_config = load_config
-        self.lora_backend = lora_backend
-        self.scaling = self.config.lora_alpha / self.config.r
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.load_config: LoadConfig = load_config
+        self.lora_backend: BaseLoRABackend = lora_backend
+        self.scaling: float = self.config.lora_alpha / self.config.r
 
-        self.layers = nn.ModuleList(
+        self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
                 for i in range(base_hf_config.num_hidden_layers)
             ]
         )
 
-        self.weights = {}
-        self.weights_gpu = {}
-
-    def get_stacked_multiply(self, module_name):
-        stacked_rank = {
-            "qkv_proj": 3,
-            "kv_proj": 2,
-            "gate_up_proj": 2,
-        }
-        return stacked_rank[module_name] if module_name in stacked_rank else 1
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weights_gpu: Dict[str, torch.Tensor] = {}
 
     def load_to_gpu(self):
         for name, weight in self.weights.items():
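The typed constructor keeps the same scaling rule as before: scaling = lora_alpha / r, taken from the adapter's PEFT config. A quick worked example with hypothetical hyperparameters:

lora_alpha, r = 16, 8     # hypothetical values from an adapter_config.json
scaling = lora_alpha / r  # -> 2.0
# This factor multiplies the low-rank delta, either inside the fused backend
# kernel or as: output = base_output + lora_delta * scaling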
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
-            [... old lines 370-410 removed; their content was not rendered in the source diff view ...]
+            self.stack_qkv_proj(weight_names, layer.weights)
+            self.stack_gate_up_proj(weight_names, layer.weights)
+
+    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
+
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
+        target_module = set()
+        for weight_name in weight_names:
+            if "k_proj" in weight_name:
+                target_module.add("k_proj")
+            if "q_proj" in weight_name:
+                target_module.add("q_proj")
+            if "v_proj" in weight_name:
+                target_module.add("v_proj")
+        if len(target_module) == 0:
+            return
+
+        for weight_name in weight_names:
+            # We assume every lora adaptor should contain lora modules for q_proj
+            if "q_proj" in weight_name:
+                q_name = weight_name
+                k_name = weight_name.replace("q_proj", "k_proj")
+                v_name = weight_name.replace("q_proj", "v_proj")
+                kv_name = weight_name.replace("q_proj", "kv_proj")
+                qkv_name = weight_name.replace("q_proj", "qkv_proj")
+
+                # If k_proj doesn't have lora, initialize it to zero
+                k_proj_weight = (
+                    weights[k_name]
+                    if "k_proj" in target_module
+                    else torch.zeros_like(weights[v_name])
+                )
+                if "lora_A" in weight_name:
+                    weights[qkv_name] = torch.cat(
+                        (
+                            weights[q_name],
+                            k_proj_weight,
+                            weights[v_name],
+                        ),
+                        0,
+                    )
+                    weights.pop(q_name)
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+                else:
+                    weights[kv_name] = torch.stack(
+                        [
+                            k_proj_weight,
+                            weights[v_name],
+                        ],
+                        dim=0,
+                    )
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+
+    def stack_gate_up_proj(
+        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
+    ):
+        for weight_name in weight_names:
+            if "gate_proj" in weight_name:
+                up_name = weight_name.replace("gate_proj", "up_proj")
+                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
+                if "lora_A" in weight_name:
+                    weights[gate_up_name] = torch.cat(
+                        (weights[weight_name], weights[up_name]), 0
+                    )
+                else:
+                    weights[gate_up_name] = torch.stack(
+                        [weights[weight_name], weights[up_name]], dim=0
+                    )
+                weights.pop(weight_name)
+                weights.pop(up_name)
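To make the effect of the new stacking helpers concrete, here is a self-contained sketch with invented key names and shapes (rank 4, hidden size 8); real adapter weight names carry full module paths and the buffer layouts come from the checkpoint, so treat everything below purely as an illustration. lora_A tensors for q/k/v are concatenated along dim 0 into a single qkv_proj entry (a missing k_proj adapter becomes zeros), lora_B tensors for k/v are stacked on a new leading axis into kv_proj, and gate/up follow the same cat-vs-stack pattern.

import torch

r, hidden = 4, 8
w = {
    "q_proj.lora_A.weight": torch.randn(r, hidden),
    "v_proj.lora_A.weight": torch.randn(r, hidden),  # this adapter has no k_proj LoRA
    "q_proj.lora_B.weight": torch.randn(hidden, r),
    "v_proj.lora_B.weight": torch.randn(hidden, r),
    "gate_proj.lora_A.weight": torch.randn(r, hidden),
    "up_proj.lora_A.weight": torch.randn(r, hidden),
    "gate_proj.lora_B.weight": torch.randn(hidden, r),
    "up_proj.lora_B.weight": torch.randn(hidden, r),
}

# lora_A branch of stack_qkv_proj: the missing k_proj is zero-filled, then q/k/v
# are concatenated along dim 0 into one qkv_proj tensor of shape (3 * r, hidden).
k_a = torch.zeros_like(w["v_proj.lora_A.weight"])
w["qkv_proj.lora_A.weight"] = torch.cat(
    (w.pop("q_proj.lora_A.weight"), k_a, w.pop("v_proj.lora_A.weight")), 0
)
assert w["qkv_proj.lora_A.weight"].shape == (3 * r, hidden)

# lora_B branch: k/v are stacked on a new leading axis into kv_proj of shape
# (2, hidden, r); q_proj.lora_B stays as its own entry.
k_b = torch.zeros_like(w["v_proj.lora_B.weight"])
w["kv_proj.lora_B.weight"] = torch.stack([k_b, w.pop("v_proj.lora_B.weight")], dim=0)
assert w["kv_proj.lora_B.weight"].shape == (2, hidden, r)

# stack_gate_up_proj follows the same pattern: cat for lora_A, stack for lora_B.
w["gate_up_proj.lora_A.weight"] = torch.cat(
    (w.pop("gate_proj.lora_A.weight"), w.pop("up_proj.lora_A.weight")), 0
)
w["gate_up_proj.lora_B.weight"] = torch.stack(
    [w.pop("gate_proj.lora_B.weight"), w.pop("up_proj.lora_B.weight")], dim=0
)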