orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/lora.py (new file)
@@ -0,0 +1,776 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ from orbit.model import BaseBlock, register_model
+
+ @register_model()
+ class LinearLoRA(BaseBlock):
+     '''LoRA (Low-Rank Adaptation) for an nn.Linear layer.
+
+     LoRA adapts a pretrained weight by injecting trainable low-rank matrices while keeping the original weight frozen.
+     Formula: h = W_0 x + B A x * scaling
+
+     Attributes:
+         original_layer (nn.Linear): The original Linear layer.
+         r (int): LoRA rank.
+         lora_alpha (int): LoRA scaling coefficient.
+         scaling (float): Effective scaling factor (lora_alpha / r).
+         gate (bool): Whether Gated LoRA is used.
+         lora_gate (nn.Parameter): Gate parameter.
+         dora (bool): Whether DoRA is used.
+         dora_m (nn.Parameter): DoRA magnitude vector.
+         merged (bool): Whether the weights have been merged.
+         lora_a (nn.Parameter): Down-projection matrix A.
+         lora_b (nn.Parameter): Up-projection matrix B.
+     '''
+     def __init__(
+         self,
+         original_layer: nn.Linear,
+         r: int = 8,
+         lora_alpha: int = 16,
+         lora_dropout: float = 0.05,
+         merge_weights: bool = False,
+         gate: bool = False,
+         dora: bool = False,
+         gradient_checkpointing: bool = False
+     ):
+         '''Initialize LinearLoRA.
+
+         Args:
+             original_layer (nn.Linear): The original Linear layer.
+             r (int): LoRA rank. Defaults to 8.
+             lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
+             lora_dropout (float): Dropout probability. Defaults to 0.05.
+             merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
+             gate (bool): Whether to use Gated LoRA. Defaults to False.
+             dora (bool): Whether to use DoRA. Defaults to False.
+             gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
+         '''
+         super().__init__()
+         self.gradient_checkpointing = gradient_checkpointing
+
+         self.in_features = original_layer.in_features
+         self.out_features = original_layer.out_features
+
+         self.original_layer = original_layer
+         for p in self.original_layer.parameters():
+             p.requires_grad = False
+
+         self.r = r
+         self.lora_alpha = lora_alpha
+         # Guard against r == 0 (adapter disabled); dividing by r directly would raise ZeroDivisionError.
+         self.scaling = lora_alpha / r if r > 0 else 0.0
+         self.merged = False
+         self.gate = gate
+
+         if r > 0:
+             self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
+             self.lora_a = nn.Parameter(torch.zeros((r, self.in_features)))
+             self.lora_b = nn.Parameter(torch.zeros((self.out_features, r)))
+             self.lora_dropout = nn.Dropout(p=lora_dropout)
+         else:
+             self.lora_a = None
+             self.lora_b = None
+             self.lora_dropout = None
+
+         self.reset_parameters()
+
+         self.dora = dora
+         if dora and r > 0:
+             # Per-output-row magnitude (dim=1), matching the dim=1 norm used in merge() and _forward_impl().
+             self.dora_m = nn.Parameter(self.original_layer.weight.norm(p=2, dim=1, keepdim=True))
+         else:
+             self.dora_m = None
+
+         # Keep the LoRA parameters on the same device as the original layer.
+         if hasattr(self.original_layer, 'weight'):
+             self.to(self.original_layer.weight.device)
+
+         if merge_weights: self.merge()
+
+     def reset_parameters(self):
+         '''Reset the LoRA parameters.
+
+         A is initialized with Kaiming uniform and B with zeros, so the LoRA branch
+         outputs zero at initialization and the model's original outputs are unchanged.
+         '''
+         if self.r > 0:
+             nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_b)
+
+     def train(self, mode: bool = True):
+         '''Set the training mode.
+
+         When entering training mode, make sure the weights are unmerged.
+
+         Args:
+             mode (bool): Whether to enable training mode.
+         '''
+         super().train(mode)
+         if not mode: return
+         if self.merged: self.unmerge()
+
+     def merge(self):
+         '''Merge the LoRA weights into the original layer's weights.
+
+         Used to speed up inference.
+         DoRA: after merging, the original weights cannot be recovered (short of storing a
+         copy of them, which would defeat LoRA's memory savings).
+         '''
+         if self.r > 0 and not self.merged:
+             if self.dora:
+                 # Calculate full weight W' = W0 + BA * scaling
+                 delta_w = (self.lora_b @ self.lora_a) * self.scaling
+                 if self.gate: delta_w *= self.lora_gate
+                 weight = self.original_layer.weight + delta_w
+
+                 # Normalize and scale: W_final = m * W' / ||W'||
+                 norm = weight.norm(p=2, dim=1, keepdim=True)
+                 weight = (weight / (norm + 1e-6)) * self.dora_m
+
+                 # Update original weight (destructive!)
+                 self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
+             else:
+                 # W_new = W_old + B @ A * scaling
+                 delta_w = (self.lora_b @ self.lora_a) * self.scaling
+                 if self.gate: delta_w *= self.lora_gate
+                 self.original_layer.weight.data += delta_w.to(self.original_layer.weight.dtype)
+
+             self.merged = True
+
+     def unmerge(self):
+         '''Subtract the LoRA weights from the original weights.
+
+         Used to restore the original weights or to continue training.
+         Note: unmerge is not supported in DoRA mode.
+         '''
+         if self.r > 0 and self.merged:
+             if self.dora:
+                 print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
+             else:
+                 delta_w = (self.lora_b @ self.lora_a) * self.scaling
+                 if self.gate: delta_w *= self.lora_gate
+                 self.original_layer.weight.data -= delta_w
+
+             self.merged = False
+
+     def _forward_impl(self, x: torch.Tensor):
+         if self.r > 0 and self.merged:
+             return self.original_layer(x)
+
+         if self.dora and self.r > 0:
+             # DoRA: W_final = m * (W0 + BA) / ||W0 + BA||
+             delta_w = (self.lora_b @ self.lora_a) * self.scaling
+             if self.gate: delta_w *= self.lora_gate
+
+             # Reconstruct the full weight for the computation
+             weight = self.original_layer.weight + delta_w
+             norm = weight.norm(p=2, dim=1, keepdim=True)
+             weight = (weight / (norm + 1e-6)) * self.dora_m
+
+             return F.linear(x, weight.to(x.dtype), self.original_layer.bias)
+
+         result = self.original_layer(x)
+
+         if self.r > 0:
+             # x shape: (batch, ..., in)
+             # lora_a shape: (r, in) -> x @ A.T -> (batch, ..., r)
+             # lora_b shape: (out, r) -> result @ B.T -> (batch, ..., out)
+             x_dropped = self.lora_dropout(x)
+             lora_out = (x_dropped @ self.lora_a.transpose(0, 1) @ self.lora_b.transpose(0, 1)) * self.scaling
+             if self.gate: lora_out *= self.lora_gate
+             result += lora_out
+
+         return result
+
+     def forward(self, x: torch.Tensor):
+         '''Forward pass.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor.
+         '''
+         if self.gradient_checkpointing and self.training:
+             if x.requires_grad:
+                 return self.checkpoint(self._forward_impl, x)
+             else:
+                 # Route a dummy tensor that requires grad through the checkpoint so it is not a no-op.
+                 dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
+                 return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
+         return self._forward_impl(x)
+
+     def __repr__(self):
+         prefix = 'Gated' if self.gate else ''
+         suffix = 'DoRA' if self.dora else 'LoRA'
+         return f'{self.__class__.__name__}(type={prefix}{suffix}, in_features={self.in_features}, out_features={self.out_features}, r={self.r}, merged={self.merged})'
+
+ @register_model()
+ class Conv2dLoRA(BaseBlock):
+     '''LoRA (Low-Rank Adaptation) for an nn.Conv2d layer.
+
+     Two consecutive convolutions emulate the low-rank factorization:
+     1. Layer A: reduces the channel count to r, keeping the original kernel_size.
+     2. Layer B: restores the channel count with a 1x1 kernel.
+
+     Attributes:
+         original_layer (nn.Conv2d): The original Conv2d layer.
+         r (int): LoRA rank.
+         lora_alpha (int): LoRA scaling coefficient.
+         scaling (float): Effective scaling factor (lora_alpha / r).
+         gate (bool): Whether Gated LoRA is used.
+         lora_gate (nn.Parameter): Gate parameter.
+         dora (bool): Whether DoRA is used.
+         dora_m (nn.Parameter): DoRA magnitude vector.
+         merged (bool): Whether the weights have been merged.
+         lora_a (nn.Conv2d): Channel-reduction convolution.
+         lora_b (nn.Conv2d): Channel-expansion convolution (1x1).
+     '''
+     def __init__(
+         self,
+         original_layer: nn.Conv2d,
+         r: int = 8,
+         lora_alpha: int = 16,
+         lora_dropout: float = 0.05,
+         merge_weights: bool = False,
+         gate: bool = False,
+         dora: bool = False,
+         gradient_checkpointing: bool = False
+     ):
+         '''Initialize Conv2dLoRA.
+
+         Args:
+             original_layer (nn.Conv2d): The original Conv2d layer.
+             r (int): LoRA rank. Defaults to 8.
+             lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
+             lora_dropout (float): Dropout probability. Defaults to 0.05.
+             merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
+             gate (bool): Whether to use Gated LoRA. Defaults to False.
+             dora (bool): Whether to use DoRA. Defaults to False.
+             gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
+         '''
+         super().__init__()
+         self.gradient_checkpointing = gradient_checkpointing
+         self.original_layer = original_layer
+         self.in_channels = original_layer.in_channels
+         self.out_channels = original_layer.out_channels
+         self.kernel_size = original_layer.kernel_size
+         self.stride = original_layer.stride
+         self.padding = original_layer.padding
+         self.dilation = original_layer.dilation
+         self.groups = original_layer.groups
+
+         for p in self.original_layer.parameters():
+             p.requires_grad = False
+
+         self.r = r
+         self.lora_alpha = lora_alpha
+         self.scaling = lora_alpha / r if r > 0 else 0.0
+         self.merged = False
+         self.gate = gate
+
+         if r > 0:
+             self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
+             self.lora_a = nn.Conv2d(
+                 self.in_channels, r,
+                 kernel_size=self.kernel_size,
+                 stride=self.stride,
+                 padding=self.padding,
+                 dilation=self.dilation,
+                 groups=self.groups,
+                 bias=False
+             )
+
+             self.lora_b = nn.Conv2d(
+                 r, self.out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 bias=False
+             )
+
+             self.lora_dropout = nn.Dropout(p=lora_dropout)
+         else:
+             self.lora_a = None
+             self.lora_b = None
+
+         self.reset_parameters()
+
+         self.dora = dora
+         if dora and r > 0:
+             # Conv2d weight: (out, in, k, k) -> norm over dim=(1, 2, 3) for each output channel
+             self.dora_m = nn.Parameter(
+                 self.original_layer.weight.norm(p=2, dim=(1, 2, 3), keepdim=True)
+             )
+         else:
+             self.dora_m = None
+
+         if hasattr(self.original_layer, 'weight'):
+             self.to(self.original_layer.weight.device)
+
+         if merge_weights: self.merge()
+
+     def reset_parameters(self):
+         '''Reset the LoRA parameters.
+
+         Convolution A is initialized with Kaiming uniform, convolution B with zeros.
+         '''
+         if self.r > 0:
+             # A: Kaiming initialization
+             nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
+             # B: zero initialization
+             nn.init.zeros_(self.lora_b.weight)
+
+     def train(self, mode: bool = True):
+         '''Set the training mode.
+
+         When entering training mode, make sure the weights are unmerged.
+
+         Args:
+             mode (bool): Whether to enable training mode.
+         '''
+         super().train(mode)
+         if mode and self.merged: self.unmerge()
+
+     def merge(self):
+         '''Merge the LoRA weights into the original convolution weights.
+
+         Uses einsum to compute the equivalent kernel of the LoRA branch and adds it to the original weight.
+         '''
+         if self.r > 0 and not self.merged:
+             weight_b = self.lora_b.weight.squeeze(3).squeeze(2)  # (out, r)
+             weight_a = self.lora_a.weight  # (r, in, k, k)
+
+             # i: out_channels, j: r, k: in_channels, m, n: kernel dims
+             delta_w = torch.einsum('ij, jkmn -> ikmn', weight_b, weight_a) * self.scaling
+             if self.gate: delta_w *= self.lora_gate
+
+             if self.dora:
+                 weight = self.original_layer.weight + delta_w
+                 norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True)
+                 weight = (weight / (norm + 1e-6)) * self.dora_m
+                 self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
+             else:
+                 self.original_layer.weight.data += delta_w
+
+             self.merged = True
+
+     def unmerge(self):
+         '''Subtract the LoRA weights from the original weights.'''
+         if self.r > 0 and self.merged:
+             if self.dora:
+                 print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
+             else:
+                 weight_b = self.lora_b.weight.squeeze(3).squeeze(2)
+                 weight_a = self.lora_a.weight
+                 delta_w = torch.einsum('ij, jkmn -> ikmn', weight_b, weight_a) * self.scaling
+                 if self.gate: delta_w *= self.lora_gate
+                 self.original_layer.weight.data -= delta_w
+
+             self.merged = False
+
+     def _forward_impl(self, x: torch.Tensor):
+         if self.r > 0 and self.merged:
+             return self.original_layer(x)
+
+         if self.dora and self.r > 0:
+             weight_b = self.lora_b.weight.squeeze(3).squeeze(2)  # (out, r)
+             weight_a = self.lora_a.weight  # (r, in, k, k)
+             delta_w = torch.einsum('ij, jkmn -> ikmn', weight_b, weight_a) * self.scaling
+             if self.gate: delta_w *= self.lora_gate
+
+             weight = self.original_layer.weight + delta_w
+             norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True)
+             weight = (weight / (norm + 1e-6)) * self.dora_m
+
+             return F.conv2d(
+                 x, weight.to(x.dtype), self.original_layer.bias,
+                 self.stride, self.padding, self.dilation, self.groups
+             )
+
+         result = self.original_layer(x)
+
+         if self.r > 0:
+             x_dropped = self.lora_dropout(x)
+             # Input -> Conv(in, r)[spatial] -> Conv(r, out)[1x1]
+             lora_out = self.lora_b(self.lora_a(x_dropped)) * self.scaling
+             if self.gate: lora_out *= self.lora_gate
+             result += lora_out
+
+         return result
+
+     def forward(self, x: torch.Tensor):
+         '''Forward pass.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor.
+         '''
+         if self.gradient_checkpointing and self.training:
+             if x.requires_grad:
+                 return self.checkpoint(self._forward_impl, x)
+             else:
+                 dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
+                 return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
+         return self._forward_impl(x)
+
+     def __repr__(self):
+         prefix = 'Gated' if self.gate else ''
+         suffix = 'DoRA' if self.dora else 'LoRA'
+         return f'{self.__class__.__name__}(type={prefix}{suffix}, in_channels={self.in_channels}, out_channels={self.out_channels}, kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}, dilation={self.dilation}, groups={self.groups}, r={self.r}, merged={self.merged})'
+
+ @register_model()
+ class Conv1dLoRA(BaseBlock):
+     '''LoRA (Low-Rank Adaptation) for an nn.Conv1d layer.
+
+     Two consecutive convolutions emulate the low-rank factorization:
+     1. Layer A: reduces the channel count to r, keeping the original kernel_size.
+     2. Layer B: restores the channel count with a kernel size of 1.
+
+     Attributes:
+         original_layer (nn.Conv1d): The original Conv1d layer.
+         r (int): LoRA rank.
+         lora_alpha (int): LoRA scaling coefficient.
+         scaling (float): Effective scaling factor (lora_alpha / r).
+         gate (bool): Whether Gated LoRA is used.
+         lora_gate (nn.Parameter): Gate parameter.
+         dora (bool): Whether DoRA is used.
+         dora_m (nn.Parameter): DoRA magnitude vector.
+         merged (bool): Whether the weights have been merged.
+         lora_a (nn.Conv1d): Channel-reduction convolution.
+         lora_b (nn.Conv1d): Channel-expansion convolution (kernel size 1).
+     '''
+     def __init__(
+         self,
+         original_layer: nn.Conv1d,
+         r: int = 8,
+         lora_alpha: int = 16,
+         lora_dropout: float = 0.05,
+         merge_weights: bool = False,
+         gate: bool = False,
+         dora: bool = False,
+         gradient_checkpointing: bool = False
+     ):
+         '''Initialize Conv1dLoRA.
+
+         Args:
+             original_layer (nn.Conv1d): The original Conv1d layer.
+             r (int): LoRA rank. Defaults to 8.
+             lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
+             lora_dropout (float): Dropout probability. Defaults to 0.05.
+             merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
+             gate (bool): Whether to use Gated LoRA. Defaults to False.
+             dora (bool): Whether to use DoRA. Defaults to False.
+             gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
+         '''
+         super().__init__()
+         self.gradient_checkpointing = gradient_checkpointing
+         self.original_layer = original_layer
+         self.in_channels = original_layer.in_channels
+         self.out_channels = original_layer.out_channels
+         self.kernel_size = original_layer.kernel_size[0]  # Conv1d kernel_size is a tuple
+         self.stride = original_layer.stride[0]
+         self.padding = original_layer.padding[0]
+         self.dilation = original_layer.dilation[0]
+         self.groups = original_layer.groups
+
+         # Freeze the original layer
+         for p in self.original_layer.parameters():
+             p.requires_grad = False
+
+         self.r = r
+         self.lora_alpha = lora_alpha
+         self.scaling = lora_alpha / r if r > 0 else 0.0
+         self.merged = False
+         self.gate = gate
+
+         if r > 0:
+             self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
+             # A: channel reduction + spatial (temporal) convolution
+             self.lora_a = nn.Conv1d(
+                 self.in_channels, r,
+                 kernel_size=self.kernel_size,
+                 stride=self.stride,
+                 padding=self.padding,
+                 dilation=self.dilation,
+                 groups=self.groups,
+                 bias=False
+             )
+             # B: channel expansion + pointwise convolution (kernel=1)
+             self.lora_b = nn.Conv1d(
+                 r, self.out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 bias=False
+             )
+             self.lora_dropout = nn.Dropout(p=lora_dropout)
+         else:
+             self.lora_a = None
+             self.lora_b = None
+
+         self.reset_parameters()
+
+         self.dora = dora
+         if dora and r > 0:
+             # Conv1d weight: (out, in, k) -> norm over dim=(1, 2)
+             self.dora_m = nn.Parameter(
+                 self.original_layer.weight.norm(p=2, dim=(1, 2), keepdim=True)
+             )
+         else:
+             self.dora_m = None
+
+         if hasattr(self.original_layer, 'weight'):
+             self.to(self.original_layer.weight.device)
+
+         if merge_weights:
+             self.merge()
+
+     def reset_parameters(self):
+         if self.r > 0:
+             nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_b.weight)
+
+     def merge(self):
+         if self.r > 0 and not self.merged:
+             # B: (out, r, 1) -> (out, r)
+             weight_b = self.lora_b.weight.squeeze(2)
+             # A: (r, in, k)
+             weight_a = self.lora_a.weight
+
+             # einsum: ij(out,r), jkn(r,in,k) -> ikn(out,in,k)
+             delta_w = torch.einsum('ij, jkn -> ikn', weight_b, weight_a) * self.scaling
+             if self.gate: delta_w *= self.lora_gate
+
+             if self.dora:
+                 weight = self.original_layer.weight + delta_w
+                 norm = weight.norm(p=2, dim=(1, 2), keepdim=True)
+                 weight = (weight / (norm + 1e-6)) * self.dora_m
+                 self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
+             else:
+                 self.original_layer.weight.data += delta_w
+
+             self.merged = True
+
+     def unmerge(self):
+         if self.r > 0 and self.merged:
+             if self.dora:
+                 print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
+             else:
+                 weight_b = self.lora_b.weight.squeeze(2)
+                 weight_a = self.lora_a.weight
+                 delta_w = torch.einsum('ij, jkn -> ikn', weight_b, weight_a) * self.scaling
+                 if self.gate: delta_w *= self.lora_gate
+                 self.original_layer.weight.data -= delta_w
+
+             self.merged = False
+
+     def _forward_impl(self, x: torch.Tensor):
+         if self.r > 0 and self.merged:
+             return self.original_layer(x)
+
+         if self.dora and self.r > 0:
+             weight_b = self.lora_b.weight.squeeze(2)
+             weight_a = self.lora_a.weight
+             delta_w = torch.einsum('ij, jkn -> ikn', weight_b, weight_a) * self.scaling
+             if self.gate: delta_w *= self.lora_gate
+
+             weight = self.original_layer.weight + delta_w
+             norm = weight.norm(p=2, dim=(1, 2), keepdim=True)
+             weight = (weight / (norm + 1e-6)) * self.dora_m
+
+             return F.conv1d(
+                 x, weight.to(x.dtype), self.original_layer.bias,
+                 self.stride, self.padding, self.dilation, self.groups
+             )
+
+         result = self.original_layer(x)
+         if self.r > 0:
+             x = self.lora_dropout(x)
+             lora_out = self.lora_b(self.lora_a(x)) * self.scaling
+             if self.gate: lora_out *= self.lora_gate
+             result += lora_out
+         return result
+
+     def forward(self, x: torch.Tensor):
+         if self.gradient_checkpointing and self.training:
+             if x.requires_grad:
+                 return self.checkpoint(self._forward_impl, x)
+             else:
+                 dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
+                 return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
+         return self._forward_impl(x)
+
+     def __repr__(self):
+         prefix = 'Gated' if self.gate else ''
+         suffix = 'DoRA' if self.dora else 'LoRA'
+         return f'{self.__class__.__name__}(type={prefix}{suffix}, in={self.in_channels}, out={self.out_channels}, kernel={self.kernel_size}, r={self.r}, merged={self.merged})'
+
+ @register_model()
+ class EmbeddingLoRA(BaseBlock):
+     '''LoRA (Low-Rank Adaptation) for an nn.Embedding layer.
+
+     Adapts the embedding weights by injecting low-rank matrices.
+     Formula: h = W_0[idx] + (A[idx] @ B.T) * scaling
+
+     Attributes:
+         original_layer (nn.Embedding): The original Embedding layer.
+         r (int): LoRA rank.
+         lora_alpha (int): LoRA scaling coefficient.
+         scaling (float): Effective scaling factor (lora_alpha / r).
+         gate (bool): Whether Gated LoRA is used.
+         lora_gate (nn.Parameter): Gate parameter.
+         dora (bool): Whether DoRA is used.
+         dora_m (nn.Parameter): DoRA magnitude vector.
+         merged (bool): Whether the weights have been merged.
+         lora_a (nn.Embedding): Down-projection Embedding layer (V, r).
+         lora_b (nn.Linear): Up-projection Linear layer (r, D).
+     '''
+     def __init__(
+         self,
+         original_layer: nn.Embedding,
+         r: int = 8,
+         lora_alpha: int = 16,
+         merge_weights: bool = False,
+         gate: bool = False,
+         dora: bool = False,
+         gradient_checkpointing: bool = False
+     ):
+         '''Initialize EmbeddingLoRA.
+
+         Args:
+             original_layer (nn.Embedding): The original Embedding layer.
+             r (int): LoRA rank. Defaults to 8.
+             lora_alpha (int): LoRA scaling coefficient. Defaults to 16.
+             merge_weights (bool): Whether to merge the LoRA weights into the original weights at initialization. Defaults to False.
+             gate (bool): Whether to use Gated LoRA. Defaults to False.
+             dora (bool): Whether to use DoRA. Defaults to False.
+             gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
+         '''
+         super().__init__()
+         self.gradient_checkpointing = gradient_checkpointing
+         self.original_layer = original_layer
+         self.num_embeddings = original_layer.num_embeddings
+         self.embedding_dim = original_layer.embedding_dim
+         self.padding_idx = original_layer.padding_idx
+
+         self.original_layer.weight.requires_grad = False
+
+         self.r = r
+         self.lora_alpha = lora_alpha
+         self.scaling = lora_alpha / r if r > 0 else 0.0
+         self.merged = False
+         self.gate = gate
+
+         if r > 0:
+             self.lora_gate = nn.Parameter(torch.tensor([1.0])) if gate else None
+             # lora_a: (num_embeddings, r)
+             self.lora_a = nn.Embedding(
+                 self.num_embeddings, r,
+                 padding_idx=self.padding_idx
+             )
+             # lora_b: (r, embedding_dim)
+             self.lora_b = nn.Linear(r, self.embedding_dim, bias=False)
+         else:
+             self.lora_a = None
+             self.lora_b = None
+
+         self.reset_parameters()
+
+         self.dora = dora
+         if dora and r > 0:
+             # Embedding weight: (V, D) -> norm over dim=1
+             self.dora_m = nn.Parameter(
+                 self.original_layer.weight.norm(p=2, dim=1, keepdim=True)
+             )
+         else:
+             self.dora_m = None
+
+         if hasattr(self.original_layer, 'weight'):
+             self.to(self.original_layer.weight.device)
+
+         if merge_weights:
+             self.merge()
+
+     def reset_parameters(self):
+         if self.r > 0:
+             nn.init.zeros_(self.lora_a.weight)
+             nn.init.normal_(self.lora_b.weight, mean=0.0, std=0.02)
+
+     def merge(self):
+         if self.r > 0 and not self.merged:
+             weight_b = self.lora_b.weight  # (D, r)
+             weight_a = self.lora_a.weight  # (V, r)
+
+             delta_w = (weight_a @ weight_b.T) * self.scaling
+             if self.gate: delta_w *= self.lora_gate
+
+             if self.dora:
+                 weight = self.original_layer.weight + delta_w
+                 norm = weight.norm(p=2, dim=1, keepdim=True)
+                 weight = (weight / (norm + 1e-6)) * self.dora_m
+                 self.original_layer.weight.data = weight.to(self.original_layer.weight.dtype)
+             else:
+                 self.original_layer.weight.data += delta_w
+
+             self.merged = True
+
+     def unmerge(self):
+         if self.r > 0 and self.merged:
+             if self.dora:
+                 print("Warning: DoRA weights cannot be unmerged exactly. Original weights are lost.")
+             else:
+                 weight_b = self.lora_b.weight
+                 weight_a = self.lora_a.weight
+                 delta_w = (weight_a @ weight_b.T) * self.scaling
+                 if self.gate: delta_w *= self.lora_gate
+                 self.original_layer.weight.data -= delta_w
+
+             self.merged = False
+
+     def _forward_impl(self, x: torch.Tensor):
+         if self.r > 0 and self.merged:
+             return self.original_layer(x)
+
+         if self.dora and self.r > 0:
+             # DoRA embedding
+             weight_b = self.lora_b.weight
+             weight_a = self.lora_a.weight
+             delta_w = (weight_a @ weight_b.T) * self.scaling
+             if self.gate: delta_w *= self.lora_gate
+
+             weight = self.original_layer.weight + delta_w
+             norm = weight.norm(p=2, dim=1, keepdim=True)
+             weight = (weight / (norm + 1e-6)) * self.dora_m
+
+             return F.embedding(
+                 x, weight.to(x.dtype if x.dtype.is_floating_point else self.original_layer.weight.dtype), self.padding_idx,
+                 self.original_layer.max_norm, self.original_layer.norm_type,
+                 self.original_layer.scale_grad_by_freq, self.original_layer.sparse
+             )
+
+         result = self.original_layer(x)
+
+         if self.r > 0:
+             # A(x): lookup -> (Batch, Len, r)
+             a_out = self.lora_a(x)
+             # B(A(x)): Linear -> (Batch, Len, Dim)
+             lora_out = self.lora_b(a_out) * self.scaling
+             if self.gate: lora_out *= self.lora_gate
+             result += lora_out
+
+         return result
+
+     def forward(self, x: torch.Tensor):
+         if self.gradient_checkpointing and self.training:
+             # Embedding inputs (indices) don't have gradients, so always use the dummy tensor trick
+             dummy = torch.tensor(0.0, requires_grad=True, device=x.device)
+             return self.checkpoint(lambda d, x: self._forward_impl(x), dummy, x)
+         return self._forward_impl(x)
+
+     def __repr__(self):
+         prefix = 'Gated' if self.gate else ''
+         suffix = 'DoRA' if self.dora else 'LoRA'
+         return f'{self.__class__.__name__}(type={prefix}{suffix}, num={self.num_embeddings}, dim={self.embedding_dim}, r={self.r}, merged={self.merged})'
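For orientation, a minimal usage sketch of the new wrappers follows. It relies only on behaviour visible in the hunk above (frozen base weights, zero-initialised lora_b, and the merge()/unmerge() API); the toy layer, the placeholder objective, and the direct import from orbit.model.block.lora are illustrative assumptions rather than documented examples from the package.

import torch
import torch.nn as nn

from orbit.model.block.lora import LinearLoRA  # module path as laid out in this diff (assumed importable directly)

base = nn.Linear(32, 64)                       # toy "pretrained" layer (illustrative)
lora_layer = LinearLoRA(base, r=8, lora_alpha=16, lora_dropout=0.05)

x = torch.randn(4, 32)

# lora_b is zero-initialised, so the wrapper initially reproduces the base layer exactly.
lora_layer.eval()
with torch.no_grad():
    assert torch.allclose(lora_layer(x), base(x))

# Only the injected parameters remain trainable; the base weights were frozen by the wrapper.
lora_layer.train()
trainable = [p for p in lora_layer.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)

loss = lora_layer(x).pow(2).mean()             # placeholder objective, just to drive a step
loss.backward()
optimizer.step()

# Fold the low-rank update into the base weight for inference, then undo it to keep training.
# As the docstrings note, merge() is destructive when dora=True and cannot be unmerged exactly.
lora_layer.merge()
y_fast = lora_layer(x)
lora_layer.unmerge()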
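The convolutional variants follow the same pattern. The short check below (again an assumption-level sketch, not taken from the package's documentation) illustrates the two-stage branch described in the Conv2dLoRA docstring: branch A keeps the wrapped layer's kernel, stride, and padding while reducing the channels to r, and branch B is a 1x1 convolution, so the branch output lines up spatially with the frozen convolution and the two can be summed.

import torch
import torch.nn as nn

from orbit.model.block.lora import Conv2dLoRA  # assumed import path, as laid out in this diff

conv = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
lora_conv = Conv2dLoRA(conv, r=4, lora_alpha=8)

x = torch.randn(2, 16, 28, 28)
y = lora_conv(x)

# Branch A maps 16 -> 4 channels with the original 3x3 / stride-2 geometry,
# branch B maps 4 -> 32 channels with a 1x1 kernel, so the output shape
# matches the frozen convolution's output shape.
assert y.shape == conv(x).shape == (2, 32, 14, 14)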