orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/utils/lora.py
ADDED
@@ -0,0 +1,480 @@
import torch
import torch.nn as nn
import json

from safetensors.torch import save_file as safe_save_file
from safetensors.torch import load_file as safe_load_file

from orbit.model.block import LinearLoRA, Conv2dLoRA, Conv1dLoRA, EmbeddingLoRA

lora_models = [LinearLoRA, Conv2dLoRA, Conv1dLoRA, EmbeddingLoRA]

def freeze_backbone_only(
    model: nn.Module,
    unlock_head_keywords: list = None,
    verbose: bool = True
):
    '''Freeze the backbone so that only LoRA layers and the specified head layers stay trainable.

    The function first freezes all parameters, then unfreezes the parameters of the LoRA
    wrapper modules (LinearLoRA, Conv2dLoRA, Conv1dLoRA, EmbeddingLoRA), and finally
    unfreezes every parameter whose name contains any keyword in unlock_head_keywords.

    Args:
        model (nn.Module): Target model.
        unlock_head_keywords (list, optional): Keywords of head layers that should stay
            unfrozen, e.g. ['head', 'fc', 'classifier']. Defaults to None.
        verbose (bool): Whether to print freeze statistics. Defaults to True.
    '''
    for param in model.parameters():
        param.requires_grad = False

    if not unlock_head_keywords: unlock_head_keywords = []

    lora_types = tuple(lora_models)

    lora_counter = 0
    for name, module in model.named_modules():
        if isinstance(module, lora_types):
            lora_counter += 1

            if hasattr(module, 'lora_a') and module.lora_a is not None:
                if isinstance(module.lora_a, nn.Module):
                    for p in module.lora_a.parameters(): p.requires_grad = True
                elif isinstance(module.lora_a, nn.Parameter):
                    module.lora_a.requires_grad = True

            if hasattr(module, 'lora_b') and module.lora_b is not None:
                if isinstance(module.lora_b, nn.Module):
                    for p in module.lora_b.parameters(): p.requires_grad = True
                elif isinstance(module.lora_b, nn.Parameter):
                    module.lora_b.requires_grad = True

            if hasattr(module, 'dora_m') and module.dora_m is not None:
                module.dora_m.requires_grad = True

            if hasattr(module, 'lora_gate') and module.lora_gate is not None:
                module.lora_gate.requires_grad = True

    for name, param in model.named_parameters():
        if any(k in name for k in unlock_head_keywords):
            param.requires_grad = True

    if verbose:
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"Backbone frozen.")
        print(f"- Active LoRA blocks: {lora_counter}")
        print(f"- Trainable params: {trainable:,} / {total:,} ({trainable/total:.2%})")
        if unlock_head_keywords:
            print(f"- Extra unlocked layers: {unlock_head_keywords}")


def merge_lora(model: nn.Module, verbose: bool = False):
    '''Merge the weights of every LoRA layer in the model into its original layer.

    Intended for inference speed-up. Walks all modules in the model, finds the LoRA
    wrapper layers, and calls their merge method.

    Args:
        model (nn.Module): Model containing LoRA layers.
        verbose (bool): Whether to print merge statistics. Defaults to False.

    Returns:
        nn.Module: The model with merged weights.
    '''
    lora_types = tuple(lora_models)
    count = 0

    for module in model.modules():
        if isinstance(module, lora_types):
            # Only merge when r > 0 and the weights have not been merged yet
            if hasattr(module, 'r') and module.r > 0 and hasattr(module, 'merged') and not module.merged:
                module.merge()
                count += 1

    if verbose:
        print(f"Merged LoRA weights in {count} modules.")

    return model

def unmerge_lora(model: nn.Module, verbose: bool = False):
    '''Undo the weight merge of every LoRA layer in the model.

    Used to restore the training state after inference. Note: in DoRA mode the original
    weights cannot be recovered exactly.

    Args:
        model (nn.Module): Model containing LoRA layers.
        verbose (bool): Whether to print status information. Defaults to False.

    Returns:
        nn.Module: The model with the merge undone.
    '''
    lora_types = tuple(lora_models)
    count = 0

    for module in model.modules():
        if isinstance(module, lora_types):
            if hasattr(module, 'merged') and module.merged:
                module.unmerge()
                count += 1

    if verbose:
        print(f"Unmerged LoRA weights in {count} modules.")

    return model


def unload_lora(model: nn.Module, merge: bool = False, verbose: bool = False):
    '''Remove the LoRA layers from the model and restore the original layers.

    Args:
        model (nn.Module): Model containing LoRA layers.
        merge (bool): Whether to merge the weights before unloading. Defaults to False.
        verbose (bool): Whether to print status information. Defaults to False.

    Returns:
        nn.Module: The model with LoRA unloaded.
    '''
    lora_types = tuple(lora_models)
    count = 0

    def _unload(parent):
        nonlocal count
        for name, child in parent.named_children():
            if isinstance(child, lora_types):
                if merge:
                    child.merge()

                if hasattr(child, 'original_layer'):
                    setattr(parent, name, child.original_layer)
                    count += 1
            else:
                _unload(child)

    _unload(model)

    if verbose:
        action = "Merged and unloaded" if merge else "Unloaded"
        print(f"{action} LoRA layers in {count} modules.")

    return model


def inject_lora(
    model: nn.Module,
    r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    gate: bool = False,
    dora: bool = False,
    target_names: list = None,
    exclude_names: list = None,
    verbose: bool = False,
    prefix: str = ""
):
    '''Inject LoRA layers into the model.

    Recursively walks the model and replaces Linear, Conv2d, Conv1d, and Embedding layers
    with the corresponding LoRA wrapper layers. Supports plain LoRA, Gated LoRA, and DoRA.

    Args:
        model (nn.Module): Target model.
        r (int): LoRA rank. Defaults to 8.
        lora_alpha (int): LoRA scaling factor. Defaults to 16.
        lora_dropout (float): Dropout probability. Defaults to 0.05.
        gate (bool): Whether to enable Gated LoRA (adds a learnable gate parameter). Defaults to False.
        dora (bool): Whether to enable DoRA (Weight-Decomposed Low-Rank Adaptation). Defaults to False.
        target_names (list, optional): Only inject into layers whose names contain these keywords.
            Accepts strings (substring match) or compiled regular expressions. Defaults to None
            (inject into all supported layers).
        exclude_names (list, optional): Skip layers whose names contain these keywords.
            Accepts strings (substring match) or compiled regular expressions. Defaults to None.
        verbose (bool): Whether to print each injected layer. Defaults to False.
        prefix (str): Path prefix used internally during recursion.

    Returns:
        nn.Module: The model with LoRA injected.
    '''
    import re

    def check_match(name, patterns):
        if patterns is None: return False
        for p in patterns:
            if isinstance(p, str):
                if p in name: return True
            elif hasattr(p, 'search'):
                if p.search(name): return True
        return False

    for name, child in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name

        should_inject = target_names is None or check_match(full_name, target_names)
        is_excluded = check_match(full_name, exclude_names)

        if not should_inject or is_excluded:
            inject_lora(child, r, lora_alpha, lora_dropout, gate, dora, target_names, exclude_names, verbose, full_name)
            continue

        new_layer = None
        if isinstance(child, nn.Linear):
            new_layer = LinearLoRA(
                child, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                gate=gate, dora=dora
            )

        elif isinstance(child, nn.Conv2d):
            if child.kernel_size == (1, 1) or child.kernel_size == 1:
                pass
            else:
                new_layer = Conv2dLoRA(
                    child, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                    gate=gate, dora=dora
                )

        elif isinstance(child, nn.Conv1d):
            new_layer = Conv1dLoRA(
                child, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                gate=gate, dora=dora
            )

        elif isinstance(child, nn.Embedding):
            new_layer = EmbeddingLoRA(
                child, r=r, lora_alpha=lora_alpha,
                gate=gate, dora=dora
            )

        if new_layer is not None:
            setattr(model, name, new_layer)
            if verbose:
                print(f"LoRA injected: {full_name} ({child.__class__.__name__})")
        else:
            # If this layer matched but is not a supported leaf layer (or is a container), keep recursing
            inject_lora(child, r, lora_alpha, lora_dropout, gate, dora, target_names, exclude_names, verbose, full_name)

    return model

def inject_lora_file(
    model: nn.Module,
    path: str,
    merge_and_unload: bool = False,
    verbose: bool = False
):
    '''Inject LoRA automatically from a weight file and load the weights.

    The function inspects the weight file, infers the LoRA hyperparameters (such as r,
    gate, dora), injects the matching LoRA layers into the model, and loads the weights.

    Args:
        model (nn.Module): Target model.
        path (str): Path to the weight file.
        merge_and_unload (bool): Whether to merge the weights and unload the LoRA layers
            after loading. Defaults to False.
        verbose (bool): Whether to print details. Defaults to False.

    Returns:
        nn.Module: The model with LoRA injected and weights loaded.
    '''
    import os
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    device = next(model.parameters()).device

    config = {}
    state_dict = {}

    if path.endswith('.safetensors'):
        # Load safetensors
        state_dict = safe_load_file(path, device=str(device))

        # Try to read the metadata
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device=str(device)) as f:
            metadata = f.metadata()
            if metadata and 'orbit_lora_config' in metadata:
                try:
                    config = json.loads(metadata['orbit_lora_config'])
                except:
                    pass
    else:
        checkpoint = torch.load(path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            config = checkpoint.get('orbit_lora_config', {})
        else:
            state_dict = checkpoint
            config = {}

    r = config.get('r', 8)
    alpha = config.get('alpha', 16)
    target_names = config.get('target_names', None)

    # Auto-detect gate and dora
    has_gate = any('lora_gate' in k for k in state_dict.keys())
    has_dora = any('dora_m' in k for k in state_dict.keys())

    if verbose:
        print(f"Injecting LoRA from {path}...")
        print(f"Config: r={r}, alpha={alpha}, gate={has_gate}, dora={has_dora}, targets={target_names}")

    inject_lora(
        model,
        r=r,
        lora_alpha=alpha,
        gate=has_gate,
        dora=has_dora,
        target_names=target_names,
        verbose=verbose
    )

    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if verbose:
        lora_missing = [k for k in missing if 'lora_' in k or 'dora_' in k]
        if lora_missing:
            print(f"Warning: Missing LoRA keys: {lora_missing}")
        else:
            print("LoRA weights loaded successfully.")

    if merge_and_unload:
        unload_lora(model, merge=True, verbose=verbose)

    return model

def save_lora(model: nn.Module, path: str):
    '''Save only the LoRA weights of the model.

    Walks the model's state_dict and saves every parameter whose key contains 'lora_'.

    Args:
        model (nn.Module): Model containing LoRA layers.
        path (str): Destination path.
    '''
    lora_state_dict = {}
    full_state_dict = model.state_dict()

    for key, value in full_state_dict.items():
        if 'lora_' in key:
            lora_state_dict[key] = value

    if path.endswith('.safetensors'):
        safe_save_file(lora_state_dict, path)
    else:
        torch.save(lora_state_dict, path)
    print(f"LoRA weights saved to {path}. Size: {len(lora_state_dict)} keys.")

def load_lora(model: nn.Module, path: str):
    '''Load LoRA weights into the model.

    Loads with strict=False and prints warnings for missing or unexpected keys.
    Supports plain weight files as well as dictionaries that contain 'model_state_dict'
    as saved by the Checkpoint plugin.

    Args:
        model (nn.Module): Target model.
        path (str): Path to the weight file.
    '''
    device = next(model.parameters()).device

    if path.endswith('.safetensors'):
        lora_state_dict = safe_load_file(path, device=str(device))
    else:
        checkpoint = torch.load(path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            lora_state_dict = checkpoint['model_state_dict']
        else:
            lora_state_dict = checkpoint

    missing, unexpected = model.load_state_dict(lora_state_dict, strict=False)

    if unexpected:
        print(f"Warning: Unexpected keys found: {unexpected}")

    lora_missing = [k for k in missing if 'lora_' in k]
    if lora_missing:
        print(f"Warning: Missing LoRA keys: {lora_missing}")
    else:
        print("LoRA weights loaded successfully.")

class LoRADiagnoser:
    @staticmethod
    def get_status(model: nn.Module, verbose: bool = False) -> dict:
        """
        Call this inside the training loop to get the current health status of the LoRA layers.
        """
        stats = {
            "total_lora_modules": 0,
            "active_grads": 0,
            "avg_update_magnitude": 0.0,
            "max_grad_norm": 0.0,
            "dead_neurons": 0
        }

        update_ratios = []

        for name, module in model.named_modules():
            if hasattr(module, 'lora_a') and hasattr(module, 'lora_b') and module.r > 0:
                stats["total_lora_modules"] += 1

                wa = module.lora_a.weight if isinstance(module.lora_a, nn.Module) else module.lora_a
                wb = module.lora_b.weight if isinstance(module.lora_b, nn.Module) else module.lora_b

                if wa.grad is not None and wb.grad is not None:
                    stats["active_grads"] += 1
                    g_norm = wa.grad.norm().item() + wb.grad.norm().item()
                    stats["max_grad_norm"] = max(stats["max_grad_norm"], g_norm)

                s = module.scaling

                norm_a = wa.data.norm().item()
                norm_b = wb.data.norm().item()
                norm_delta = norm_a * norm_b * s

                if hasattr(module, 'original_layer'):
                    # Conv LoRA
                    norm_w = module.original_layer.weight.data.norm().item()
                elif hasattr(module, 'weight'):
                    norm_w = module.weight.data.norm().item()
                else:
                    norm_w = 1.0  # Fallback

                ratio = norm_delta / (norm_w + 1e-6)
                update_ratios.append(ratio)

                if norm_b < 1e-9:
                    stats["dead_neurons"] += 1

        if update_ratios:
            stats["avg_update_magnitude"] = sum(update_ratios) / len(update_ratios)
            stats["min_ratio"] = min(update_ratios)
            stats["max_ratio"] = max(update_ratios)

        if verbose:
            print(f"--- LoRA Diagnosis ---")
            print(f"Modules: {stats['total_lora_modules']} | Active Grads: {stats['active_grads']}")
            print(f"Update Ratio (Perturbation): {stats['avg_update_magnitude']:.6f} (Target ~0.001 - 0.01)")
            print(f"Max Gradient Norm: {stats['max_grad_norm']:.6f}")
            if stats['dead_neurons'] > 0:
                print(f"Warning: {stats['dead_neurons']} modules have near-zero output (Initialization issue?)")
            print(f"----------------------")

        return stats

    @staticmethod
    def check_collapse(model: nn.Module, threshold: float = 1e-4):
        """
        Check whether the LoRA matrices show severe rank collapse.
        """
        print("Running SVD analysis on LoRA layers...")
        for name, module in model.named_modules():
            if hasattr(module, 'lora_a') and hasattr(module, 'lora_b'):
                wa = module.lora_a.weight if isinstance(module.lora_a, nn.Module) else module.lora_a

                if wa.dim() > 2:
                    wa_flat = wa.view(wa.shape[0], -1)  # (r, in*k*k)
                else:
                    wa_flat = wa

                try:
                    _, S, _ = torch.svd(wa_flat)
                    # Normalize the singular values
                    S = S / S[0]
                    effective_rank = (S > threshold).sum().item()
                    print(f"[{name}] Rank: {module.r} | Effective Rank: {effective_rank} | Top/Bottom Ratio: {S[0]/S[-1]:.1f}")
                except:
                    pass
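A minimal usage sketch of the helpers above. The torchvision backbone, target names, and file paths are illustrative assumptions, not part of the package; only the orbit.utils.lora functions shown are taken from the diff.

# Illustrative only: resnet18, the target names, and the file name are placeholders.
from torchvision.models import resnet18

from orbit.utils.lora import (
    inject_lora, freeze_backbone_only, save_lora, load_lora, merge_lora, LoRADiagnoser
)

model = resnet18(num_classes=10)

# Wrap supported layers whose qualified name contains 'layer4' or 'fc' with LoRA adapters.
inject_lora(model, r=8, lora_alpha=16, target_names=['layer4', 'fc'], verbose=True)

# Freeze the backbone; only LoRA parameters and the classifier head stay trainable.
freeze_backbone_only(model, unlock_head_keywords=['fc'])

# ... training loop ...
# LoRADiagnoser.get_status(model, verbose=True)  # optional per-epoch adapter health check

# Persist only the adapter weights, reload them elsewhere, then fold them into the base weights.
save_lora(model, 'lora_adapter.safetensors')
load_lora(model, 'lora_adapter.safetensors')
merge_lora(model, verbose=True)

For deployment, unload_lora(model, merge=True) removes the wrapper layers entirely after merging, instead of leaving merged LoRA wrappers in place.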
orbit/utils/moe.py
ADDED
@@ -0,0 +1,55 @@
import torch.nn as nn
from typing import Literal
from orbit.utils.freeze import set_trainable

def set_moe_training_mode(
    model: nn.Module,
    mode: Literal['all', 'router_only', 'experts_only'] = 'all',
    verbose: bool = True
) -> None:
    '''Set the training mode of an MoE (Mixture of Experts) model.

    Finds the MoE module instances in the model and precisely controls the frozen/unfrozen
    state of the Router (gating network) and the Experts (expert networks).

    Args:
        model (nn.Module): Model containing MoE blocks.
        mode (str): Training mode.
            - 'all': Train the Router and the Experts jointly (default).
            - 'router_only': Train only the Router, freeze the Experts.
            - 'experts_only': Train only the Experts, freeze the Router.
        verbose (bool): Whether to print the status change.
    '''
    from orbit.model.block.moe import MoE

    moe_modules = [m for m in model.modules() if isinstance(m, MoE)]

    if not moe_modules:
        if verbose:
            print("Warning: No MoE modules found in the provided model.")
        return

    count = len(moe_modules)

    for moe in moe_modules:
        if mode == 'all':
            set_trainable(moe.router, trainable=True)
            set_trainable(moe.experts, trainable=True)

        elif mode == 'router_only':
            set_trainable(moe.experts, trainable=False)
            set_trainable(moe.router, trainable=True)

        elif mode == 'experts_only':
            set_trainable(moe.router, trainable=False)
            set_trainable(moe.experts, trainable=True)

        else:
            raise ValueError(f"Unknown mode: {mode}. Supported modes: 'all', 'router_only', 'experts_only'")

    if verbose:
        mode_desc = {
            'all': "ALL (Router: Unfrozen, Experts: Unfrozen)",
            'router_only': "ROUTER_ONLY (Router: Unfrozen, Experts: Frozen)",
            'experts_only': "EXPERTS_ONLY (Router: Frozen, Experts: Unfrozen)"
        }
        print(f"MoE Training Mode set to: {mode_desc[mode]} for {count} MoE module(s).")
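For example, a two-stage schedule that first adapts the experts and then fine-tunes routing and experts jointly; the staging itself is illustrative, and model is assumed to contain orbit MoE blocks.

from orbit.utils.moe import set_moe_training_mode

# Stage 1: adapt the experts while keeping the routing decisions frozen.
set_moe_training_mode(model, mode='experts_only')
# ... train for a few epochs ...

# Stage 2: unfreeze everything for joint fine-tuning.
set_moe_training_mode(model, mode='all')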
orbit/utils/seed.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 import random
 import os
 
-def seed_everything(seed=42, strict=False):
+def seed_everything(seed=42, strict=False, warn_only=True):
     """
     Set all random seeds to ensure the reproducibility of PyTorch experiments.
 
@@ -16,38 +16,22 @@ def seed_everything(seed=42, strict=False):
     """
     import orbit
     orbit.seed_info = seed
-    # 1. Seed Python's built-in random
     random.seed(seed)
+    np.random.seed(seed)
 
-    # 2. Set the Python hash seed (affects dict/set iteration order)
     os.environ['PYTHONHASHSEED'] = str(seed)
 
-    # 3. Seed NumPy
-    np.random.seed(seed)
-
-    # 4. Seed PyTorch on CPU/GPU
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
-    # 5. Configure the cuDNN backend (standard reproducibility settings)
     if torch.cuda.is_available():
-        # Disable algorithm auto-tuning (the chosen algorithm can vary with hardware state)
         torch.backends.cudnn.benchmark = False
-        # Force deterministic algorithms
         torch.backends.cudnn.deterministic = True
-    # 6. Strict mode
     if strict:
         try:
-
-            # Note: operations without a deterministic PyTorch implementation raise a RuntimeError
-            torch.use_deterministic_algorithms(True)
-
-            # For use_deterministic_algorithms to work on CUDA,
-            # CUBLAS_WORKSPACE_CONFIG must be set, otherwise cuBLAS raises an error.
-            # ":4096:8" is the officially recommended setting, at the cost of a little extra memory
+            torch.use_deterministic_algorithms(True, warn_only=warn_only)
             os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-
             print(f"[Info] Strict deterministic mode enabled. (seed={seed})")
         except AttributeError:
             print("[Warning] torch.use_deterministic_algorithms is not available in your PyTorch version.")
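With the new warn_only parameter, strict mode no longer has to hard-fail on operators that lack a deterministic implementation; an illustrative call:

from orbit.utils.seed import seed_everything

# Request full determinism; with warn_only=True (the new default), ops without a
# deterministic kernel emit a warning instead of raising a RuntimeError.
seed_everything(seed=42, strict=True, warn_only=True)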
orbit/utils/sft.py
ADDED
@@ -0,0 +1,93 @@
from typing import Dict, Any, TYPE_CHECKING

if TYPE_CHECKING: from orbit.engine import Engine

def build_sft(
    user_content: str,
    model_content: str,
    tokenizer: Any,
    max_length: int = 2048,
    model_role: str = 'model',
    padding: bool = True,
    ignore_index: int = -100
) -> Dict[str, Any]:
    '''Build an SFT (Supervised Fine-Tuning) dataset sample.

    Combines the user input and the model output, applies the chat template, then tokenizes,
    truncates, and pads. Also builds labels in which the user-input part and the padding part
    are masked (set to ignore_index).

    Args:
        user_content (str): The user's input.
        model_content (str): The expected model reply.
        tokenizer (Any): Tokenizer instance; must support apply_chat_template and __call__.
        max_length (int, optional): Maximum sequence length. Defaults to 2048.
        model_role (str, optional): Role name of the model, used to build the chat messages.
            Defaults to 'model'.
        padding (bool, optional): Whether to pad to max_length. Defaults to True.
        ignore_index (int, optional): Index used to mask labels, consistent with PyTorch loss
            functions. Defaults to -100.

    Returns:
        Dict[str, Any]: Dictionary with the processed data:
            - 'input_ids': Input token ID tensor.
            - 'attention_mask': Attention mask tensor.
            - 'labels': Label tensor for the loss, with the user part masked.
    '''
    messages = [
        {'role': 'user', 'content': user_content},
        {'role': model_role, 'content': model_content}
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False
    )

    user_messages = [{'role': 'user', 'content': user_content}]
    prompt_text = tokenizer.apply_chat_template(
        user_messages,
        tokenize=False,
        add_generation_prompt=True
    )

    encodings = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding='max_length' if padding else False,
        return_tensors='pt'
    )

    input_ids = encodings['input_ids'][0]
    attention_mask = encodings['attention_mask'][0]
    labels = input_ids.clone()

    prompt_encodings = tokenizer(
        prompt_text,
        truncation=True,
        max_length=max_length,
        add_special_tokens=True
    )
    prompt_len = len(prompt_encodings['input_ids'])

    if prompt_len > len(input_ids):
        prompt_len = len(input_ids)

    labels[:prompt_len] = ignore_index

    if tokenizer.pad_token_id is not None:
        labels[input_ids == tokenizer.pad_token_id] = ignore_index

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

def train_sft(engine: 'Engine'):
    output = engine.unwrap_model()(
        input_ids=engine.data['input_ids'],
        attention_mask=engine.data['attention_mask'],
        labels=engine.data['labels']
    )

    engine.update(output.loss)
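A sketch of how build_sft and train_sft fit together. The Hugging Face checkpoint name and the role name are assumptions for illustration; only the keys returned by build_sft and the Engine attributes used by train_sft come from the code above.

# Hypothetical wiring: the checkpoint name and the role name are placeholders.
from transformers import AutoTokenizer
from orbit.utils.sft import build_sft

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')

sample = build_sft(
    user_content='Translate "hello" into French.',
    model_content='Bonjour.',
    tokenizer=tokenizer,
    max_length=512,
    model_role='assistant',   # use whatever role name your chat template expects
)
# sample['labels'] masks the prompt and padding tokens with -100, so the loss
# is computed only over the model reply.

# train_sft expects an orbit Engine whose current batch (engine.data) carries these
# three keys; it runs the forward pass and hands output.loss to engine.update().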