orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/lora.py

@@ -0,0 +1,776 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from orbit.model import BaseBlock, register_model

@register_model()
class LinearLoRA(BaseBlock):
    '''LoRA (Low-Rank Adaptation) for nn.Linear layers.

    LoRA adapts pretrained weights by injecting trainable low-rank matrices while keeping the original weights frozen.
    Formula: h = W_0 x + B A x * scaling

    Attributes:
        original_layer (nn.Linear): The original Linear layer.
        r (int): LoRA rank.
        lora_alpha (int): LoRA scaling coefficient.
        scaling (float): Effective scaling factor (lora_alpha / r).
        gate (bool): Whether Gated LoRA is used.
        lora_gate (nn.Parameter): Gate parameter.
        dora (bool): Whether DoRA is used.
        dora_m (nn.Parameter): DoRA magnitude vector.
        merged (bool): Whether the weights have been merged.
        lora_a (nn.Parameter): Down-projection matrix A.
        lora_b (nn.Parameter): Up-projection matrix B.
    '''
    def __init__(
        self,
        original_layer: nn.Linear,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        merge_weights: bool = False,
        gate: bool = False,
        dora: bool = False,
        gradient_checkpointing: bool = False
    ):
        '''Initialize LinearLoRA.

        Args:
            original_layer (nn.Linear): The original Linear layer.
            r (int): LoRA rank. Defaults to 8.
            lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
            lora_dropout (float): Dropout probability. Defaults to 0.05.
            merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
            gate (bool): Whether to use Gated LoRA. Defaults to False.
            dora (bool): Whether to use DoRA. Defaults to False.
            gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
        '''
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing

        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features

        self.original_layer = original_layer
        for p in self.original_layer.parameters():
            p.requires_grad = False

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        self.merged = False
        self.gate = gate

        if r > 0:
            self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
            self.lora_a = nn.Parameter(torch.zeros((r, self.in_features)))
            self.lora_b = nn.Parameter(torch.zeros((self.out_features, r)))
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_a = None
            self.lora_b = None
            self.lora_dropout = None

        self.reset_parameters()

        self.dora = dora
        if dora and r > 0:
            self.dora_m = nn.Parameter(self.original_layer.weight.norm(p=2, dim=0, keepdim=True))
        else:
            self.dora_m = None

        # Ensure the LoRA parameters live on the same device as the original layer
        if hasattr(self.original_layer, 'weight'):
            self.to(self.original_layer.weight.device)

        if merge_weights: self.merge()

    def reset_parameters(self):
        '''Reset the LoRA parameters.

        Matrix A uses Kaiming uniform initialization; matrix B is initialized to zero.
        This guarantees the LoRA branch outputs zero at initialization, leaving the model's original output unchanged.
        '''
        if self.r > 0:
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b)

    def train(self, mode: bool = True):
        '''Set training mode.

        When entering training mode, make sure the weights are unmerged.

        Args:
            mode (bool): Whether to enable training mode.
        '''
        super().train(mode)
        if not mode: return
        if self.merged: self.unmerge()

    def merge(self):
        '''Merge the LoRA weights into the original layer's weights.

        Used to speed up inference.
        DoRA: the original weights cannot be recovered after merging (unless a copy is stored, which defeats LoRA's memory savings).
        '''
        if self.r > 0 and not self.merged:
            if self.dora:
                # Calculate full weight W' = W0 + BA * scaling
                delta_w = (self.lora_b @ self.lora_a) * self.scaling
                if self.gate: delta_w *= self.lora_gate
                weight = self.original_layer.weight + delta_w

                # Normalize and scale: W_final = m * W' / ||W'||
                norm = weight.norm(p=2, dim=1, keepdim=True)
                weight = (weight / (norm + 1e-6)) * self.dora_m

                # Update original weight (Destructive!)
                self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
            else:
                # W_new = W_old + B @ A * scaling
                delta_w = (self.lora_b @ self.lora_a) * self.scaling
                if self.gate: delta_w *= self.lora_gate
                self.original_layer.weight.data += delta_w.to(self.original_layer.weight.dtype)

            self.merged = True

    def unmerge(self):
        '''Subtract the LoRA weights from the original weights.

        Used to restore the original weights or to continue training.
        Note: unmerge is not supported in DoRA mode.
        '''
        if self.r > 0 and self.merged:
            if self.dora:
                print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
                pass
            else:
                delta_w = (self.lora_b @ self.lora_a) * self.scaling
                if self.gate: delta_w *= self.lora_gate
                self.original_layer.weight.data -= delta_w

            self.merged = False

    def _forward_impl(self, x: torch.Tensor):
        if self.r > 0 and self.merged:
            return self.original_layer(x)

        if self.dora and self.r > 0:
            # DoRA: W_final = m * (W0 + BA) / ||W0 + BA||
            delta_w = (self.lora_b @ self.lora_a) * self.scaling
            if self.gate: delta_w *= self.lora_gate

            # Reconstruct full weight for calculation
            weight = self.original_layer.weight + delta_w
            norm = weight.norm(p=2, dim=1, keepdim=True)
            weight = (weight / (norm + 1e-6)) * self.dora_m

            return F.linear(x, weight.to(x.dtype), self.original_layer.bias)

        result = self.original_layer(x)

        if self.r > 0:
            # x shape: (batch, ..., in)
            # lora_a shape: (r, in) -> x @ A.T -> (batch, ..., r)
            # lora_b shape: (out, r) -> result @ B.T -> (batch, ..., out)
            x_dropped = self.lora_dropout(x)
            lora_out = (x_dropped @ self.lora_a.transpose(0, 1) @ self.lora_b.transpose(0, 1)) * self.scaling
            if self.gate: lora_out *= self.lora_gate
            result += lora_out

        return result

    def forward(self, x: torch.Tensor):
        '''Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        '''
        if self.gradient_checkpointing and self.training:
            if x.requires_grad:
                return self.checkpoint(self._forward_impl, x)
            else:
                dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
                return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
        return self._forward_impl(x)

    def __repr__(self):
        prefix = 'Gated' if self.gate else ''
        suffix = 'DoRA' if self.dora else 'LoRA'
        return f'{self.__class__.__name__}(type={prefix}{suffix}, in_features={self.in_features}, out_features={self.out_features}, r={self.r}, merged={self.merged})'
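The docstring above states the update rule h = W_0 x + B A x * scaling with B initialised to zero, so a freshly wrapped layer should reproduce the frozen layer exactly, and merge() should not change the output it produces. A minimal sanity-check sketch (illustrative, not part of the diff; the import path orbit.model.block.lora is assumed from the file list, and dropout is disabled to make the comparison exact):

```python
import torch
import torch.nn as nn

from orbit.model.block.lora import LinearLoRA  # path assumed from the file list above

base = nn.Linear(64, 32)
lora = LinearLoRA(base, r=8, lora_alpha=16, lora_dropout=0.0)
lora.eval()

x = torch.randn(4, 64)

# B is zero-initialised, so the LoRA branch contributes nothing at first.
assert torch.allclose(lora(x), base(x))

# Only the injected parameters remain trainable; the wrapped weights are frozen.
print([n for n, p in lora.named_parameters() if p.requires_grad])  # ['lora_a', 'lora_b']

# Stand-in for a few training steps, then fold the update into the base weight.
with torch.no_grad():
    lora.lora_b.add_(0.01 * torch.randn_like(lora.lora_b))
out_unmerged = lora(x)
lora.merge()
assert torch.allclose(lora(x), out_unmerged, atol=1e-6)
print(lora)  # LinearLoRA(type=LoRA, in_features=64, out_features=32, r=8, merged=True)
```

Because the branch starts at zero, the adapter can be attached to a pretrained checkpoint without perturbing its behaviour; training then only has to move the small (r, in) and (out, r) factors.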
@register_model()
class Conv2dLoRA(BaseBlock):
    '''LoRA (Low-Rank Adaptation) for nn.Conv2d layers.

    Two consecutive convolutions emulate the low-rank factorization:
    1. Layer A: reduces the channel count to r, keeping kernel_size.
    2. Layer B: restores the channel count with a 1x1 kernel.

    Attributes:
        original_layer (nn.Conv2d): The original Conv2d layer.
        r (int): LoRA rank.
        lora_alpha (int): LoRA scaling coefficient.
        scaling (float): Effective scaling factor (lora_alpha / r).
        gate (bool): Whether Gated LoRA is used.
        lora_gate (nn.Parameter): Gate parameter.
        dora (bool): Whether DoRA is used.
        dora_m (nn.Parameter): DoRA magnitude vector.
        merged (bool): Whether the weights have been merged.
        lora_a (nn.Conv2d): Down-projection convolution.
        lora_b (nn.Conv2d): Up-projection convolution (1x1).
    '''
    def __init__(
        self,
        original_layer: nn.Conv2d,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        merge_weights: bool = False,
        gate: bool = False,
        dora: bool = False,
        gradient_checkpointing: bool = False
    ):
        '''Initialize Conv2dLoRA.

        Args:
            original_layer (nn.Conv2d): The original Conv2d layer.
            r (int): LoRA rank. Defaults to 8.
            lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
            lora_dropout (float): Dropout probability. Defaults to 0.05.
            merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
            gate (bool): Whether to use Gated LoRA. Defaults to False.
            dora (bool): Whether to use DoRA. Defaults to False.
            gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
        '''
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing
        self.original_layer = original_layer
        self.in_channels = original_layer.in_channels
        self.out_channels = original_layer.out_channels
        self.kernel_size = original_layer.kernel_size
        self.stride = original_layer.stride
        self.padding = original_layer.padding
        self.dilation = original_layer.dilation
        self.groups = original_layer.groups

        for p in self.original_layer.parameters():
            p.requires_grad = False

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        self.merged = False
        self.gate = gate

        if r > 0:
            self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
            self.lora_a = nn.Conv2d(
                self.in_channels, r,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                bias=False
            )

            self.lora_b = nn.Conv2d(
                r, self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False
            )

            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_a = None
            self.lora_b = None

        self.reset_parameters()

        self.dora = dora
        if dora and r > 0:
            # Conv2d weight: (out, in, k, k) -> norm dim=(1,2,3) for each output channel
            self.dora_m = nn.Parameter(
                self.original_layer.weight.norm(p=2, dim=(1, 2, 3), keepdim=True)
            )
        else:
            self.dora_m = None

        if hasattr(self.original_layer, 'weight'):
            self.to(self.original_layer.weight.device)

        if merge_weights: self.merge()

    def reset_parameters(self):
        '''Reset the LoRA parameters.

        Convolution A uses Kaiming uniform initialization; convolution B is initialized to zero.
        '''
        if self.r > 0:
            # A: Kaiming initialization
            nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
            # B: zero initialization
            nn.init.zeros_(self.lora_b.weight)

    def train(self, mode: bool = True):
        '''Set training mode.

        When entering training mode, make sure the weights are unmerged.

        Args:
            mode (bool): Whether to enable training mode.
        '''
        super().train(mode)
        if mode and self.merged: self.unmerge()

    def merge(self):
        '''Merge the LoRA weights into the original convolution's weights.

        Uses einsum to compute the equivalent convolution kernel of the LoRA branch and adds it to the original weight.
        '''
        if self.r > 0 and not self.merged:
            weight_b = self.lora_b.weight.squeeze(3).squeeze(2) # (out, r)
            weight_a = self.lora_a.weight # (r, in, k, k)

            # i: out_channels, j: r, k: in_channels, m, n: kernel dims
            delta_w = torch.einsum('ij, jkmn -> ikmn', weight_b, weight_a) * self.scaling
            if self.gate: delta_w *= self.lora_gate

            if self.dora:
                weight = self.original_layer.weight + delta_w
                norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True)
                weight = (weight / (norm + 1e-6)) * self.dora_m
                self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
            else:
                self.original_layer.weight.data += delta_w

            self.merged = True

    def unmerge(self):
        '''Subtract the LoRA weights from the original weights.'''
        if self.r > 0 and self.merged:
            if self.dora:
                print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
            else:
                weight_b = self.lora_b.weight.squeeze(3).squeeze(2)
                weight_a = self.lora_a.weight
                delta_w = torch.einsum('ij, jkmn -> ikmn', weight_b, weight_a) * self.scaling
                if self.gate: delta_w *= self.lora_gate
                self.original_layer.weight.data -= delta_w

            self.merged = False

    def _forward_impl(self, x: torch.Tensor):
        if self.r > 0 and self.merged:
            return self.original_layer(x)

        if self.dora and self.r > 0:
            weight_b = self.lora_b.weight.squeeze(3).squeeze(2) # (out, r)
            weight_a = self.lora_a.weight # (r, in, k, k)
            delta_w = torch.einsum('ij, jkmn -> ikmn', weight_b, weight_a) * self.scaling
            if self.gate: delta_w *= self.lora_gate

            weight = self.original_layer.weight + delta_w
            norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True)
            weight = (weight / (norm + 1e-6)) * self.dora_m

            return F.conv2d(
                x, weight.to(x.dtype), self.original_layer.bias,
                self.stride, self.padding, self.dilation, self.groups
            )

        result = self.original_layer(x)

        if self.r > 0:
            x_dropped = self.lora_dropout(x)
            # Input -> Conv(in, r)[spatial] -> Conv(r, out)[1x1]
            lora_out = self.lora_b(self.lora_a(x_dropped)) * self.scaling
            if self.gate: lora_out *= self.lora_gate
            result += lora_out

        return result

    def forward(self, x: torch.Tensor):
        '''Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        '''
        if self.gradient_checkpointing and self.training:
            if x.requires_grad:
                return self.checkpoint(self._forward_impl, x)
            else:
                dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
                return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
        return self._forward_impl(x)

    def __repr__(self):
        prefix = 'Gated' if self.gate else ''
        suffix = 'DoRA' if self.dora else 'LoRA'
        return f'{self.__class__.__name__}(type={prefix}{suffix}, in_channels={self.in_channels}, out_channels={self.out_channels}, kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}, dilation={self.dilation}, groups={self.groups}, r={self.r}, merged={self.merged})'
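As the Conv2dLoRA docstring explains, the low-rank branch is a spatial convolution down to r channels followed by a 1x1 convolution back up, and merge() folds that pair into a single kernel with einsum. A small check that the folded kernel matches the two-conv composition (illustrative, not part of the diff; import path assumed as above, standard groups=1 convolution):

```python
import torch
import torch.nn as nn

from orbit.model.block.lora import Conv2dLoRA  # path assumed from the file list above

base = nn.Conv2d(16, 32, kernel_size=3, padding=1)
lora = Conv2dLoRA(base, r=4, lora_alpha=8, lora_dropout=0.0)
lora.eval()

# Give the zero-initialised B conv some values so the branch actually contributes.
nn.init.normal_(lora.lora_b.weight, std=0.02)

x = torch.randn(2, 16, 8, 8)
out_two_convs = lora(x)   # base(x) + lora_b(lora_a(x)) * scaling
lora.merge()              # folds B (out, r, 1, 1) and A (r, in, 3, 3) into one (out, in, 3, 3) kernel
out_merged = lora(x)      # now a single convolution with the updated weight

print(torch.allclose(out_two_convs, out_merged, atol=1e-5))  # expected: True
```

The 1x1 B convolution is what keeps the fold exact: composing it with A is just a per-output-channel linear combination of A's kernels, which is precisely what the 'ij, jkmn -> ikmn' einsum computes.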
@register_model()
class Conv1dLoRA(BaseBlock):
    '''LoRA (Low-Rank Adaptation) for nn.Conv1d layers.

    Two consecutive convolutions emulate the low-rank factorization:
    1. Layer A: reduces the channel count to r, keeping kernel_size.
    2. Layer B: restores the channel count with a 1x1 kernel.

    Attributes:
        original_layer (nn.Conv1d): The original Conv1d layer.
        r (int): LoRA rank.
        lora_alpha (int): LoRA scaling coefficient.
        scaling (float): Effective scaling factor (lora_alpha / r).
        gate (bool): Whether Gated LoRA is used.
        lora_gate (nn.Parameter): Gate parameter.
        dora (bool): Whether DoRA is used.
        dora_m (nn.Parameter): DoRA magnitude vector.
        merged (bool): Whether the weights have been merged.
        lora_a (nn.Conv1d): Down-projection convolution.
        lora_b (nn.Conv1d): Up-projection convolution (1x1).
    '''
    def __init__(
        self,
        original_layer: nn.Conv1d,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        merge_weights: bool = False,
        gate: bool = False,
        dora: bool = False,
        gradient_checkpointing: bool = False
    ):
        '''Initialize Conv1dLoRA.

        Args:
            original_layer (nn.Conv1d): The original Conv1d layer.
            r (int): LoRA rank. Defaults to 8.
            lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
            lora_dropout (float): Dropout probability. Defaults to 0.05.
            merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
            gate (bool): Whether to use Gated LoRA. Defaults to False.
            dora (bool): Whether to use DoRA. Defaults to False.
            gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
        '''
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing
        self.original_layer = original_layer
        self.in_channels = original_layer.in_channels
        self.out_channels = original_layer.out_channels
        self.kernel_size = original_layer.kernel_size[0] # Conv1d kernel_size is a tuple
        self.stride = original_layer.stride[0]
        self.padding = original_layer.padding[0]
        self.dilation = original_layer.dilation[0]
        self.groups = original_layer.groups

        # Freeze the original layer
        for p in self.original_layer.parameters():
            p.requires_grad = False

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        self.merged = False
        self.gate = gate

        if r > 0:
            self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
            # A: channel reduction + spatial (temporal) convolution
            self.lora_a = nn.Conv1d(
                self.in_channels, r,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                bias=False
            )
            # B: channel expansion + pointwise convolution (kernel=1)
            self.lora_b = nn.Conv1d(
                r, self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False
            )
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_a = None
            self.lora_b = None

        self.reset_parameters()

        self.dora = dora
        if dora and r > 0:
            # Conv1d weight: (out, in, k) -> norm dim=(1,2)
            self.dora_m = nn.Parameter(
                self.original_layer.weight.norm(p=2, dim=(1, 2), keepdim=True)
            )
        else:
            self.dora_m = None

        if hasattr(self.original_layer, 'weight'):
            self.to(self.original_layer.weight.device)

        if merge_weights:
            self.merge()

    def reset_parameters(self):
        if self.r > 0:
            nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b.weight)

    def merge(self):
        if self.r > 0 and not self.merged:
            # B: (out, r, 1) -> (out, r)
            weight_b = self.lora_b.weight.squeeze(2)
            # A: (r, in, k)
            weight_a = self.lora_a.weight

            # einsum: ij(out,r), jkn(r,in,k) -> ikn(out,in,k)
            delta_w = torch.einsum('ij, jkn -> ikn', weight_b, weight_a) * self.scaling
            if self.gate: delta_w *= self.lora_gate

            if self.dora:
                weight = self.original_layer.weight + delta_w
                norm = weight.norm(p=2, dim=(1, 2), keepdim=True)
                weight = (weight / (norm + 1e-6)) * self.dora_m
                self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
            else:
                self.original_layer.weight.data += delta_w

            self.merged = True

    def unmerge(self):
        if self.r > 0 and self.merged:
            if self.dora:
                print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
            else:
                weight_b = self.lora_b.weight.squeeze(2)
                weight_a = self.lora_a.weight
                delta_w = torch.einsum('ij, jkn -> ikn', weight_b, weight_a) * self.scaling
                if self.gate: delta_w *= self.lora_gate
                self.original_layer.weight.data -= delta_w

            self.merged = False

    def _forward_impl(self, x: torch.Tensor):
        if self.r > 0 and self.merged:
            return self.original_layer(x)

        if self.dora and self.r > 0:
            weight_b = self.lora_b.weight.squeeze(2)
            weight_a = self.lora_a.weight
            delta_w = torch.einsum('ij, jkn -> ikn', weight_b, weight_a) * self.scaling
            if self.gate: delta_w *= self.lora_gate

            weight = self.original_layer.weight + delta_w
            norm = weight.norm(p=2, dim=(1, 2), keepdim=True)
            weight = (weight / (norm + 1e-6)) * self.dora_m

            return F.conv1d(
                x, weight.to(x.dtype), self.original_layer.bias,
                self.stride, self.padding, self.dilation, self.groups
            )

        result = self.original_layer(x)
        if self.r > 0:
            x = self.lora_dropout(x)
            lora_out = self.lora_b(self.lora_a(x)) * self.scaling
            if self.gate: lora_out *= self.lora_gate
            result += lora_out
        return result

    def forward(self, x: torch.Tensor):
        if self.gradient_checkpointing and self.training:
            if x.requires_grad:
                return self.checkpoint(self._forward_impl, x)
            else:
                dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
                return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
        return self._forward_impl(x)

    def __repr__(self):
        prefix = 'Gated' if self.gate else ''
        suffix = 'DoRA' if self.dora else 'LoRA'
        return f'{self.__class__.__name__}(type={prefix}{suffix}, in={self.in_channels}, out={self.out_channels}, kernel={self.kernel_size}, r={self.r}, merged={self.merged})'
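Conv1dLoRA mirrors the 2D wrapper, with the A convolution running over the temporal axis. The gate and dora flags behave the same way across all the wrappers in this file; a brief sketch of both on Conv1d layers (illustrative values, not part of the diff; import path assumed as above):

```python
import torch
import torch.nn as nn

from orbit.model.block.lora import Conv1dLoRA  # path assumed from the file list above

x = torch.randn(2, 8, 32)

# Gated LoRA adds a single learnable scalar that multiplies the low-rank update.
gated = Conv1dLoRA(nn.Conv1d(8, 16, kernel_size=5, padding=2), r=4, gate=True, lora_dropout=0.0)
print(gated)            # Conv1dLoRA(type=GatedLoRA, in=8, out=16, kernel=5, r=4, merged=False)
print(gated.lora_gate)  # Parameter containing: tensor([1.], requires_grad=True)

# DoRA keeps a per-output-channel magnitude vector; with B still zero the direction
# is W0's own, so the output matches the frozen layer up to the 1e-6 norm epsilon.
dora = Conv1dLoRA(nn.Conv1d(8, 16, kernel_size=5, padding=2), r=4, dora=True, lora_dropout=0.0)
dora.eval()
print(dora.dora_m.shape)  # torch.Size([16, 1, 1])
print(torch.allclose(dora(x), dora.original_layer(x), atol=1e-4))  # expected: True
```

One behavioural difference worth noting: unlike LinearLoRA and Conv2dLoRA, this class does not override train(), so a merged Conv1dLoRA is not unmerged automatically when training resumes; callers would need to call unmerge() themselves.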
@register_model()
class EmbeddingLoRA(BaseBlock):
    '''LoRA (Low-Rank Adaptation) for nn.Embedding layers.

    Adapts the embedding weights by injecting low-rank matrices.
    Formula: h = W_0[idx] + (A[idx] @ B.T) * scaling

    Attributes:
        original_layer (nn.Embedding): The original Embedding layer.
        r (int): LoRA rank.
        lora_alpha (int): LoRA scaling coefficient.
        scaling (float): Effective scaling factor (lora_alpha / r).
        gate (bool): Whether Gated LoRA is used.
        lora_gate (nn.Parameter): Gate parameter.
        dora (bool): Whether DoRA is used.
        dora_m (nn.Parameter): DoRA magnitude vector.
        merged (bool): Whether the weights have been merged.
        lora_a (nn.Embedding): Down-projection Embedding layer (V, r).
        lora_b (nn.Linear): Up-projection Linear layer (r, D).
    '''
    def __init__(
        self,
        original_layer: nn.Embedding,
        r: int = 8,
        lora_alpha: int = 16,
        merge_weights: bool = False,
        gate: bool = False,
        dora: bool = False,
        gradient_checkpointing: bool = False
    ):
        '''Initialize EmbeddingLoRA.

        Args:
            original_layer (nn.Embedding): The original Embedding layer.
            r (int): LoRA rank. Defaults to 8.
            lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
            merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
            gate (bool): Whether to use Gated LoRA. Defaults to False.
            dora (bool): Whether to use DoRA. Defaults to False.
            gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
        '''
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing
        self.original_layer = original_layer
        self.num_embeddings = original_layer.num_embeddings
        self.embedding_dim = original_layer.embedding_dim
        self.padding_idx = original_layer.padding_idx

        self.original_layer.weight.requires_grad = False

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        self.merged = False
        self.gate = gate

        if r > 0:
            self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
            # lora_a: (num_embeddings, r)
            self.lora_a = nn.Embedding(
                self.num_embeddings, r,
                padding_idx=self.padding_idx
            )
            # lora_b: (r, embedding_dim)
            self.lora_b = nn.Linear(r, self.embedding_dim, bias=False)
        else:
            self.lora_a = None
            self.lora_b = None

        self.reset_parameters()

        self.dora = dora
        if dora and r > 0:
            # Embedding weight: (V, D) -> norm dim=1
            self.dora_m = nn.Parameter(
                self.original_layer.weight.norm(p=2, dim=1, keepdim=True)
            )
        else:
            self.dora_m = None

        if hasattr(self.original_layer, 'weight'):
            self.to(self.original_layer.weight.device)

        if merge_weights:
            self.merge()

    def reset_parameters(self):
        if self.r > 0:
            nn.init.zeros_(self.lora_a.weight)
            nn.init.normal_(self.lora_b.weight, mean=0.0, std=0.02)

    def merge(self):
        if self.r > 0 and not self.merged:
            weight_b = self.lora_b.weight # (D, r)
            weight_a = self.lora_a.weight # (V, r)

            delta_w = (weight_a @ weight_b.T) * self.scaling
            if self.gate: delta_w *= self.lora_gate

            if self.dora:
                weight = self.original_layer.weight + delta_w
                norm = weight.norm(p=2, dim=1, keepdim=True)
                weight = (weight / (norm + 1e-6)) * self.dora_m
                self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
            else:
                self.original_layer.weight.data += delta_w

            self.merged = True

    def unmerge(self):
        if self.r > 0 and self.merged:
            if self.dora:
                print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
            else:
                weight_b = self.lora_b.weight
                weight_a = self.lora_a.weight
                delta_w = (weight_a @ weight_b.T) * self.scaling
                if self.gate: delta_w *= self.lora_gate
                self.original_layer.weight.data -= delta_w

            self.merged = False

    def _forward_impl(self, x: torch.Tensor):
        if self.r > 0 and self.merged:
            return self.original_layer(x)

        if self.dora and self.r > 0:
            # DoRA embedding
            weight_b = self.lora_b.weight
            weight_a = self.lora_a.weight
            delta_w = (weight_a @ weight_b.T) * self.scaling
            if self.gate: delta_w *= self.lora_gate

            weight = self.original_layer.weight + delta_w
            norm = weight.norm(p=2, dim=1, keepdim=True)
            weight = (weight / (norm + 1e-6)) * self.dora_m

            return F.embedding(
                x, weight.to(x.dtype if x.dtype.is_floating_point else self.original_layer.weight.dtype), self.padding_idx,
                self.original_layer.max_norm, self.original_layer.norm_type,
                self.original_layer.scale_grad_by_freq, self.original_layer.sparse
            )

        result = self.original_layer(x)

        if self.r > 0:
            # A(x): Look up -> (Batch, Len, r)
            a_out = self.lora_a(x)
            # B(A(x)): Linear -> (Batch, Len, Dim)
            lora_out = self.lora_b(a_out) * self.scaling
            if self.gate: lora_out *= self.lora_gate
            result += lora_out

        return result

    def forward(self, x: torch.Tensor):
        if self.gradient_checkpointing and self.training:
            # Embedding inputs (indices) don't have gradients, so we always use the dummy tensor trick
            dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
            return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
        return self._forward_impl(x)

    def __repr__(self):
        prefix = 'Gated' if self.gate else ''
        suffix = 'DoRA' if self.dora else 'LoRA'
        return f'{self.__class__.__name__}(type={prefix}{suffix}, num={self.num_embeddings}, dim={self.embedding_dim}, r={self.r}, merged={self.merged})'
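For embeddings the initialisation is flipped relative to the Linear case: the per-token factor lora_a (V, r) starts at zero and lora_b (r, D) is drawn from a normal distribution, so the wrapped layer still reproduces the frozen lookup exactly at initialisation, per the formula h = W_0[idx] + (A[idx] @ B.T) * scaling. A short usage sketch (illustrative, not part of the diff; import path assumed as above):

```python
import torch
import torch.nn as nn

from orbit.model.block.lora import EmbeddingLoRA  # path assumed from the file list above

base = nn.Embedding(1000, 64, padding_idx=0)
lora = EmbeddingLoRA(base, r=8, lora_alpha=16)

ids = torch.tensor([[0, 5, 42, 7]])

# A is zero-initialised here, so the adapted lookup equals the frozen one.
assert torch.allclose(lora(ids), base(ids))

# The trainable state is just the two small factors: (V, r) and (D, r).
print(lora.lora_a.weight.shape, lora.lora_b.weight.shape)  # torch.Size([1000, 8]) torch.Size([64, 8])

# merge() writes W0 + (A @ B.T) * scaling back into the embedding table.
lora.merge()
print(lora)  # EmbeddingLoRA(type=LoRA, num=1000, dim=64, r=8, merged=True)
```

The padding_idx is forwarded to lora_a, so its pad row is excluded from gradient updates, mirroring the behaviour of the base embedding.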