orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/plugin/lora.py
ADDED
@@ -0,0 +1,346 @@
+import os
+import torch
+import inspect
+import json
+
+from typing import Optional, List, TYPE_CHECKING
+from orbit.callback import Event
+from orbit.plugin.checkpoint import Checkpoint, safe_save_file, safe_load_file, safe_open
+from orbit.utils.lora import inject_lora, freeze_backbone_only
+
+if TYPE_CHECKING: from orbit.engine import Engine
+
+class LoRA(Checkpoint):
+    '''LoRA plugin: integrates LoRA injection, freezing, and lightweight saving/loading.
+
+    This plugin inherits from Checkpoint. At the start of training it automatically
+    converts the model into a LoRA model, rebuilds the optimizer to match the new
+    parameter set, and provides lightweight checkpoints that save only the trainable parameters.
+    '''
+    def __init__(
+        self,
+        name: str = "lora_model",
+        path: str = "checkpoints",
+        # LoRA parameters
+        r: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.05,
+        target_names: Optional[List[str]] = None,
+        exclude_names: Optional[List[str]] = None,
+        unlock_head_keywords: Optional[List[str]] = None,
+        gate: bool = False,
+        dora: bool = False,
+        # Checkpoint parameters
+        monitor: str = 'val_loss',
+        mode: str = 'min',
+        save_top_k: int = 1,
+        save_last: bool = True,
+        every_n_train_steps: Optional[int] = None,
+        use_safetensors: bool = False,
+        verbose: bool = True
+    ):
+        '''Initialize the LoRA plugin.
+
+        Args:
+            name (str): Model name prefix.
+            path (str): Checkpoint save directory.
+            r (int): LoRA rank.
+            lora_alpha (int): LoRA scaling factor.
+            lora_dropout (float): LoRA dropout.
+            target_names (list, optional): Only inject into layers whose names contain these strings.
+            exclude_names (list, optional): Exclude layers whose names contain these strings.
+            unlock_head_keywords (list, optional): Keywords for layers to unfreeze in addition to the LoRA layers (e.g. a classification head).
+            gate (bool): Whether to use Gated LoRA.
+            dora (bool): Whether to use DoRA.
+            monitor (str): Metric to monitor.
+            mode (str): Monitoring mode ('min'/'max').
+            save_top_k (int): Number of best models to keep.
+            save_last (bool): Whether to save the last model.
+            every_n_train_steps (int, optional): Save every N training steps.
+            use_safetensors (bool): Whether to save in safetensors format.
+            verbose (bool): Whether to print logs.
+        '''
+        # Initialize Checkpoint with save_weights_only=False to preserve the training state;
+        # the filtering logic is customized in _save.
+        super().__init__(
+            name=name,
+            path=path,
+            save_weights_only=False,
+            monitor=monitor,
+            mode=mode,
+            save_top_k=save_top_k,
+            save_last=save_last,
+            every_n_train_steps=every_n_train_steps,
+            use_safetensors=use_safetensors,
+            verbose=verbose
+        )
+
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout = lora_dropout
+        self.target_names = target_names
+        self.exclude_names = exclude_names
+        self.unlock_head_keywords = unlock_head_keywords
+        self.gate = gate
+        self.dora = dora
+
+        self.injected = False
+
+    def on_init(self, event: Event):
+        engine = event.engine
+
+        # 0. Conflict detection: check for other Checkpoint plugins
+        other_checkpoints = [p for p in engine.plugins if isinstance(p, Checkpoint) and p is not self]
+        if other_checkpoints:
+            engine.print("[yellow]Warning: Multiple Checkpoint plugins detected. Since 'LoRA' inherits from 'Checkpoint', using both may cause conflicts (e.g. double saving). Suggest removing the standard 'Checkpoint' plugin.[/]", plugin='LoRA')
+
+        model = engine.unwrap_model()
+
+        # 1. Inject LoRA and freeze the backbone
+        if not self.injected:
+            engine.print(f"[cyan]Injecting LoRA (r={self.r}, alpha={self.lora_alpha})...[/]", plugin='LoRA')
+            inject_lora(
+                model,
+                r=self.r,
+                lora_alpha=self.lora_alpha,
+                lora_dropout=self.lora_dropout,
+                gate=self.gate,
+                dora=self.dora,
+                target_names=self.target_names,
+                exclude_names=self.exclude_names
+            )
+
+            freeze_backbone_only(
+                model,
+                unlock_head_keywords=self.unlock_head_keywords,
+                verbose=self.verbose
+            )
+
+            model.to(engine.device)
+
+            # 2. Rebuild the optimizer
+            # The parameter set has changed (LoRA parameters added, most parameters frozen),
+            # so the old optimizer can no longer be used. Try to re-initialize it from the
+            # old optimizer's class and defaults.
+            if engine.optimizer:
+                old_opt = engine.optimizer
+                trainable_params = [p for p in model.parameters() if p.requires_grad]
+
+                if not trainable_params:
+                    engine.print("[red]Warning: No trainable parameters found after LoRA injection![/]", plugin='LoRA')
+                else:
+                    opt_cls = old_opt.__class__
+                    defaults = old_opt.defaults
+
+                    # Filter out arguments that are not in the constructor, in case a library
+                    # (e.g. transformers) stored extra metadata in defaults that would break re-construction.
+                    sig = inspect.signature(opt_cls.__init__)
+                    has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
+                    if has_kwargs:
+                        filtered_defaults = defaults
+                    else:
+                        filtered_defaults = {k: v for k, v in defaults.items() if k in sig.parameters}
+
+                    engine.print(f"[cyan]Re-initializing Optimizer {opt_cls.__name__} for {len(trainable_params)} trainable groups...[/]", plugin='LoRA')
+
+                    # Create the new optimizer
+                    if opt_cls.__name__ == 'SAM':
+                        base_opt_cls = old_opt.base_optimizer.__class__
+                        new_opt = opt_cls(trainable_params, base_optimizer=base_opt_cls, **filtered_defaults)
+                    elif opt_cls.__name__ == 'Muon':
+                        muon_params = [p for p in trainable_params if p.ndim == 2]
+                        adamw_params = [p for p in trainable_params if p.ndim != 2]
+                        new_opt = opt_cls(muon_params=muon_params, adamw_params=adamw_params, **filtered_defaults)
+                    else:
+                        new_opt = opt_cls(trainable_params, **filtered_defaults)
+
+                    # Swap the optimizer held by the Engine
+                    engine.optimizer = new_opt
+
+                    # Re-link the scheduler
+                    if engine.scheduler:
+                        if hasattr(engine.scheduler, 'optimizer'):
+                            engine.scheduler.optimizer = new_opt
+                            engine.print("[yellow]Note: Scheduler optimizer reference updated.[/]", plugin='LoRA')
+                        else:
+                            engine.print("[yellow]Warning: Scheduler exists but cannot auto-update optimizer reference. Please verify scheduler behavior.[/]", plugin='LoRA')
+
+            self.injected = True
+
+        # 3. Run Checkpoint initialization (create directories, attempt resume)
+        super().on_init(event)
+
+    def _save(self, engine: 'Engine', filename: str, verbose: bool = True, is_step: bool = False):
+        """Override the save logic: save only trainable parameters (LoRA + heads) and the training state."""
+
+        if self.monitor:
+            engine.meta[self._meta_key] = {'best_k_models': self.best_k_models}
+
+        raw_model = engine.unwrap_model()
+        full_state_dict = raw_model.state_dict()
+        lora_state_dict = {}
+
+        # Filter: keep parameters with requires_grad plus LoRA/DoRA-related keys
+        for name, param in raw_model.named_parameters():
+            if param.requires_grad:
+                lora_state_dict[name] = full_state_dict[name]
+
+        # Make sure buffers (e.g. BN running stats) in unfrozen layers are also saved,
+        # along with any buffers whose names contain lora/dora (they usually have none).
+        for key, value in full_state_dict.items():
+            if 'lora_' in key or 'dora_' in key:
+                if key not in lora_state_dict:
+                    lora_state_dict[key] = value
+
+        lora_config = {
+            'r': self.r,
+            'alpha': self.lora_alpha,
+            'target_names': self.target_names,
+            'unlock_head_keywords': self.unlock_head_keywords
+        }
+
+        file_path = os.path.join(self.path, filename)
+
+        try:
+            if self.use_safetensors and filename.endswith('.safetensors'):
+                # Safetensors mode
+                metadata = {
+                    'epoch': str(engine.epoch),
+                    'global_step': str(engine.global_step),
+                    'batch_idx': str(engine.batch_idx),
+                    'is_step': str(is_step),
+                    'orbit_lora_config': json.dumps(lora_config)
+                }
+                try:
+                    metadata['meta'] = json.dumps(engine.meta)
+                except: pass
+
+                safe_save_file(lora_state_dict, file_path, metadata=metadata)
+
+            else:
+                # Torch mode
+                state = {
+                    'epoch': engine.epoch,
+                    'global_step': engine.global_step,
+                    'batch_idx': engine.batch_idx,
+                    'is_step': is_step,
+                    'model_state_dict': lora_state_dict,
+                    'optimizer_state_dict': engine.optimizer.state_dict() if engine.optimizer else None,
+                    'scheduler_state_dict': engine.scheduler.state_dict() if engine.scheduler else None,
+                    'scaler_state_dict': engine.scaler.state_dict() if engine.scaler else None,
+                    'meta': engine.meta,
+                    'orbit_lora_config': lora_config
+                }
+                torch.save(state, file_path)
+
+            if verbose:
+                rel_path = os.path.relpath(file_path)
+                file_size = os.path.getsize(file_path) / 1024 / 1024
+                engine.print(f"Saved LoRA checkpoint: {rel_path} ({file_size:.2f} MB)", plugin='LoRA')
+        except Exception as e:
+            engine.print(f"[red]Failed to save checkpoint: {e}[/]", plugin='LoRA')
+
+    def _load(self, engine: 'Engine', file_path: str):
+        """Override the load logic to support strict=False loading."""
+        engine.print(f"[cyan]Loading LoRA checkpoint from: {file_path}[/]", plugin='LoRA')
+        try:
+            raw_model = engine.unwrap_model()
+            model_sd = None
+            saved_config = {}
+
+            if file_path.endswith('.safetensors'):
+                model_sd = safe_load_file(file_path, device=str(engine.device))
+
+                with safe_open(file_path, framework="pt", device=str(engine.device)) as f:
+                    metadata = f.metadata()
+                    if metadata and 'orbit_lora_config' in metadata:
+                        try:
+                            saved_config = json.loads(metadata['orbit_lora_config'])
+                        except: pass
+            else:
+                checkpoint = torch.load(file_path, map_location=engine.device)
+                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                    model_sd = checkpoint['model_state_dict']
+                    saved_config = checkpoint.get('orbit_lora_config', {})
+                else:
+                    model_sd = checkpoint if not isinstance(checkpoint, dict) else checkpoint
+
+            # 1. Load model parameters
+            if model_sd is not None:
+                # Configuration check
+                if saved_config:
+                    if saved_config.get('r') != self.r:
+                        engine.print(f"[yellow]Warning: Loaded LoRA rank ({saved_config.get('r')}) != Current rank ({self.r})[/]", plugin='LoRA')
+
+                missing, unexpected = raw_model.load_state_dict(model_sd, strict=False)
+
+                # Filter out missing-key warnings for the backbone; only the LoRA part matters
+                relevant_missing = [k for k in missing if 'lora_' in k or 'dora_' in k or any(h in k for h in (self.unlock_head_keywords or []))]
+                if relevant_missing:
+                    engine.print(f"[yellow]Warning: Missing relevant keys: {relevant_missing}[/]", plugin='LoRA')
+                else:
+                    engine.print("[green]LoRA weights loaded successfully.[/]", plugin='LoRA')
+
+            # 2. Restore training state
+            if file_path.endswith('.safetensors'):
+                # Safetensors restores metadata only
+                with safe_open(file_path, framework="pt", device=str(engine.device)) as f:
+                    metadata = f.metadata()
+                    if metadata:
+                        loaded_epoch = int(metadata.get('epoch', 0))
+                        loaded_batch_idx = int(metadata.get('batch_idx', -1))
+                        is_step = metadata.get('is_step', 'False') == 'True'
+                        engine.global_step = int(metadata.get('global_step', 0))
+
+                        if 'meta' in metadata:
+                            try:
+                                engine.meta.update(json.loads(metadata['meta']))
+                            except: pass
+
+                        if is_step:
+                            engine.start_epoch = loaded_epoch
+                            engine.start_batch_idx = loaded_batch_idx
+                            msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
+                        else:
+                            engine.start_epoch = loaded_epoch + 1
+                            engine.start_batch_idx = -1
+                            msg = f"Epoch {engine.start_epoch}"
+
+                        engine.print(f"[green]Resumed training from {msg}. Note: Optimizer state not restored (safetensors).[/]", plugin='LoRA')
+            else:
+                # Torch restores the full state
+                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                    if engine.optimizer and 'optimizer_state_dict' in checkpoint:
+                        try:
+                            engine.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                        except Exception as e:
+                            engine.print(f"[yellow]Warning: Failed to load optimizer state: {e}.[/]", plugin='LoRA')
+
+                    if engine.scheduler and 'scheduler_state_dict' in checkpoint:
+                        try:
+                            engine.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                        except: pass
+
+                    if engine.scaler and 'scaler_state_dict' in checkpoint:
+                        engine.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+
+                    if 'meta' in checkpoint:
+                        engine.meta.update(checkpoint['meta'])
+
+                    loaded_epoch = checkpoint.get('epoch', 0)
+                    loaded_batch_idx = checkpoint.get('batch_idx', -1)
+                    is_step = checkpoint.get('is_step', False)
+
+                    if is_step:
+                        engine.start_epoch = loaded_epoch
+                        engine.start_batch_idx = loaded_batch_idx
+                        msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
+                    else:
+                        engine.start_epoch = loaded_epoch + 1
+                        engine.start_batch_idx = -1
+                        msg = f"Epoch {engine.start_epoch}"
+
+                    engine.global_step = checkpoint.get('global_step', 0)
+                    engine.print(f"[green]Resuming training from {msg}[/]", plugin='LoRA')
+
+        except Exception as e:
+            engine.print(f"[red]Failed to load checkpoint: {e}[/]", plugin='LoRA')
+            import traceback
+            traceback.print_exc()
orbit/plugin/memory_estimator.py
CHANGED
orbit/plugin/warmup.py
CHANGED
orbit/utils/__init__.py
CHANGED
@@ -14,7 +14,7 @@ from .freeze import (
     set_trainable,
     freeze_layers,
     unfreeze_layers,
-
+    count_params, ParamStats,
 )
 from .seed import (
     seed_everything,
@@ -27,3 +27,26 @@ from .mask import (
     make_causal_mask,
     make_sliding_window_mask
 )
+from .layer_io import (
+    get_model_by_name,
+    save_layer, load_layer,
+    save_model, load_model
+)
+from .lora import (
+    save_lora, load_lora, inject_lora, inject_lora_file,
+    merge_lora, unmerge_lora,
+    freeze_backbone_only,
+    LoRADiagnoser
+)
+from .sft import (
+    build_sft, train_sft
+)
+from .cuda import (
+    cuda_alloc
+)
+from .image import (
+    split_to_patches, reconstruct_from_patches, pad_to_patch_size
+)
+from .moe import (
+    set_moe_training_mode
+)
orbit/utils/cuda.py
ADDED
orbit/utils/freeze.py
CHANGED
@@ -1,10 +1,20 @@
 import torch
 import torch.nn as nn
-from typing import Union, List, Optional, Iterable
+from typing import Union, List, Optional, Iterable, Type
+from dataclasses import dataclass
+
+@dataclass
+class ParamStats:
+    count: int
+    params: List[torch.Tensor]
+
+    def __iter__(self):
+        return iter(self.params)
 
 def set_trainable(
     model: nn.Module,
     targets: Optional[Union[str, List[str]]] = None,
+    target_classes: Optional[Union[Type[nn.Module], List[Type[nn.Module]]]] = None,
     trainable: bool = False
 ) -> None:
     '''Set the requires_grad attribute of model parameters to freeze or unfreeze layers.
@@ -12,48 +22,82 @@ def set_trainable(
     Args:
         model (nn.Module): Target model.
         targets (str or List[str], optional): Name patterns of the layers or parameters to operate on.
-            - If None, all model parameters are affected.
+            - If None and target_classes is also None, all model parameters are affected.
             - If a str, all parameters whose names contain that string are affected.
             - If a List[str], all parameters whose names contain any string in the list are affected.
+        target_classes (Type[nn.Module] or List[Type[nn.Module]], optional): Module classes to operate on.
+            - If given, the parameters of all modules that are instances of these classes (or their subclasses) are affected.
         trainable (bool): Whether the parameters are trainable (True unfreezes, False freezes).
     '''
-    if targets is None:
+    if targets is None and target_classes is None:
         for param in model.parameters():
             param.requires_grad = trainable
-
+        return
+
+    if targets is not None:
         if isinstance(targets, str):
             targets = [targets]
 
         for name, param in model.named_parameters():
-
-
-
+            if any(t in name for t in targets): param.requires_grad = trainable
+
+    if target_classes is not None:
+        if not isinstance(target_classes, (list, tuple)): target_classes = [target_classes]
+
+        target_classes = tuple(target_classes)
 
-
+        for module in model.modules():
+            if isinstance(module, target_classes):
+                for param in module.parameters(): param.requires_grad = trainable
+
+def freeze_layers(
+    model: nn.Module,
+    targets: Optional[Union[str, List[str]]] = None,
+    target_classes: Optional[Union[Type[nn.Module], List[Type[nn.Module]]]] = None
+) -> None:
     '''Freeze the specified layers, or all layers, of a model (requires_grad=False).
 
     Args:
         model (nn.Module): Target model.
-        targets (str or List[str], optional):
+        targets (str or List[str], optional): Name patterns of the layers to freeze.
+        target_classes (Type[nn.Module] or List[Type[nn.Module]], optional): Module classes to freeze.
+            If neither targets nor target_classes is given, the whole model is frozen.
     '''
-    set_trainable(model, targets, trainable=False)
+    set_trainable(model, targets, target_classes, trainable=False)
 
-def unfreeze_layers(
+def unfreeze_layers(
+    model: nn.Module,
+    targets: Optional[Union[str, List[str]]] = None,
+    target_classes: Optional[Union[Type[nn.Module], List[Type[nn.Module]]]] = None
+) -> None:
     '''Unfreeze the specified layers, or all layers, of a model (requires_grad=True).
 
     Args:
         model (nn.Module): Target model.
-        targets (str or List[str], optional):
+        targets (str or List[str], optional): Name patterns of the layers to unfreeze.
+        target_classes (Type[nn.Module] or List[Type[nn.Module]], optional): Module classes to unfreeze.
+            If neither targets nor target_classes is given, the whole model is unfrozen.
     '''
-    set_trainable(model, targets, trainable=True)
+    set_trainable(model, targets, target_classes, trainable=True)
 
-def
-    '''
+def count_params(model: nn.Module, mode: str = 'trainable') -> ParamStats:
+    '''Count model parameters and collect the parameter list.
 
     Args:
         model (nn.Module): Target model.
+        mode (str): Counting mode, one of 'trainable', 'frozen', 'all'. Defaults to 'trainable'.
 
     Returns:
-
+        ParamStats: A dataclass holding the total parameter count (count) and the parameter list (params).
     '''
-
+    if mode == 'trainable':
+        params = [p for p in model.parameters() if p.requires_grad]
+    elif mode == 'frozen':
+        params = [p for p in model.parameters() if not p.requires_grad]
+    elif mode == 'all':
+        params = list(model.parameters())
+    else:
+        raise ValueError(f"Invalid mode: {mode}. Must be one of 'trainable', 'frozen', 'all'.")
+
+    count = sum(p.numel() for p in params)
+    return ParamStats(count=count, params=params)
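
The new target_classes path and the ParamStats return type compose naturally. A short sketch using only what this diff defines:

import torch
import torch.nn as nn
from orbit.utils.freeze import freeze_layers, unfreeze_layers, count_params

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# Freeze everything, then unfreeze BatchNorm layers by class rather than by name.
freeze_layers(model)
unfreeze_layers(model, target_classes=nn.BatchNorm2d)

stats = count_params(model)               # mode='trainable' by default
print(stats.count)                        # 32: BN weight (16) + bias (16)

# ParamStats.__iter__ yields the tensors, so it can feed an optimizer directly.
optimizer = torch.optim.SGD(count_params(model), lr=1e-3)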
|