orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/engine.py
CHANGED
@@ -8,20 +8,21 @@ from typing import Any, List, Optional, Union, Dict, Tuple
 try: from torch.utils.tensorboard import SummaryWriter
 except: pass
 
-from rich.progress
-from rich.console
+from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, MofNCompleteColumn
+from rich.console import Console
+from accelerate import Accelerator
 
 from orbit.callback import Callback, Forward, Event
-from orbit.plugin
-from orbit.
-
+from orbit.plugin import Checkpoint, Board, ModelSummary
+from orbit.utils import load_model
+
 
 class Engine:
     '''Training-loop controller; coordinates model training, validation, and callback events.
 
-    Engine wraps the PyTorch
-
-    TensorBoard visualization, and more.
+    Engine wraps the PyTorch training loop and is deeply integrated with the Accelerate library.
+    It provides out-of-the-box distributed training support (DDP, FSDP, DeepSpeed, etc.),
+    automatic mixed precision (AMP), gradient clipping, gradient accumulation, checkpoint management, and TensorBoard visualization.
     '''
 
     class _OutLogs:
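For orientation, a minimal usage sketch of the Accelerate-backed Engine described in this docstring. The constructor arguments, the chaining behavior of set_checkpoint/init_board, and the run() entry point all appear in the hunks below; the import path, toy model, and directory names are illustrative assumptions.

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from orbit.engine import Engine  # import path inferred from this file's location

    model = nn.Linear(16, 2)
    engine = (
        Engine(
            model,
            optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3),
            criterion=nn.CrossEntropyLoss(),
            mixed_precision='bf16',  # forwarded to Accelerator; needs supporting hardware
            grad_clip_norm=1.0,
        )
        .set_checkpoint('checkpoints')  # main-process only, returns self
        .init_board('runs')             # TensorBoard Board plugin, main-process only
    )

    loader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))), batch_size=8)
    engine.run(loader, num_epochs=3)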
@@ -39,63 +40,70 @@ class Engine:
         model: nn.Module,
         optimizer: torch.optim.Optimizer = None,
         criterion: nn.Module = None,
-
-
-        use_amp: bool = False,
+        accelerator: Optional[Accelerator] = None,
+        mixed_precision: str = 'no',  # 'no', 'fp16', 'bf16'
         grad_clip_norm: float = None,
         scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
         plugins: List[Callback] = None,
         forward_step: Optional[Forward] = None,
         checkpoint_dir: str = None,
         console: Console = None,
+        **accelerator_kwargs
     ):
         '''Initialize an Engine instance.
 
         Args:
             model (nn.Module): The PyTorch model to train.
-            optimizer (torch.optim.Optimizer, optional): The optimizer. If None
+            optimizer (torch.optim.Optimizer, optional): The optimizer. If None, updates must be handled manually elsewhere.
             criterion (nn.Module, optional): The loss function. If None, the model output is assumed to contain the loss, or a custom loss computation is used.
-
-
-
+            accelerator (Optional[Accelerator], optional): A pre-initialized Accelerator instance.
+                If None, a new one is created from mixed_precision and accelerator_kwargs.
+            mixed_precision (str, optional): Mixed-precision mode, one of 'no', 'fp16', 'bf16'. Defaults to 'no'.
             grad_clip_norm (float, optional): Norm threshold for gradient clipping. If None, no clipping is performed.
             scheduler (Optional[torch.optim.lr_scheduler._LRScheduler], optional): Learning-rate scheduler.
             plugins (List[Callback], optional): Callback plugins to attach at initialization.
             forward_step (Optional[Forward], optional): Custom forward-pass and loss-computation logic.
             checkpoint_dir (str, optional): Shortcut parameter for the checkpoint save directory.
             console (Console, optional): Rich Console instance for log output. If None, a new one is created.
+            **accelerator_kwargs: Extra arguments passed to the Accelerator constructor (e.g. cpu, split_batches).
         '''
-        # ---
-
-
-        if self.device_ids and len(self.device_ids) > 0:
-            self.device = torch.device(f"cuda:{self.device_ids[0]}")
-        elif device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # --- Accelerator initialization ---
+        if accelerator:
+            self.accelerator = accelerator
         else:
-            self.
-
-            #
-            model = model
-
-
-
-
+            self.accelerator = Accelerator(mixed_precision=mixed_precision, **accelerator_kwargs)
+
+        # --- Core components ---
+        self.model = model
+        self.optimizer = optimizer
+        self.criterion = criterion
+        self.scheduler = scheduler
+
+        # Prepare the model, optimizer, and scheduler.
+        # Note: DataLoaders are prepared in run/train/eval.
+        if self.optimizer:
+            if self.scheduler:
+                self.model, self.optimizer, self.scheduler = self.accelerator.prepare(
+                    self.model, self.optimizer, self.scheduler
+                )
+            else:
+                self.model, self.optimizer = self.accelerator.prepare(
+                    self.model, self.optimizer
+                )
         else:
-            self.model = model
+            self.model = self.accelerator.prepare(self.model)
 
         self.model_name = self.unwrap_model().__class__.__name__
-        self.optimizer = optimizer
-        self.criterion = criterion
 
         # --- Training configuration ---
-        self.use_amp = use_amp
         self.grad_clip_norm = grad_clip_norm
-        self.scheduler = scheduler
         self.forward_step = forward_step
-        self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
 
         # --- Interaction and callbacks ---
+        self.prepare = self.accelerator.prepare
+        self.backward = self.accelerator.backward
+        self.autocast = self.accelerator.autocast
+
         self.console = console if console else Console()
         self.out_logs = self._OutLogs(self)
         self.writer: Optional[SummaryWriter] = None
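The constructor now accepts a pre-built Accelerator, so an instance configured elsewhere in an application can be shared with the Engine instead of letting it create its own. A sketch under only the signature shown above; split_batches and cpu are standard Accelerator arguments echoed from the docstring:

    import torch
    import torch.nn as nn
    from accelerate import Accelerator
    from orbit.engine import Engine

    model = nn.Linear(8, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Share a pre-configured Accelerator with the Engine; the constructor then
    # wraps the model and optimizer with accelerator.prepare exactly as above.
    acc = Accelerator(mixed_precision='bf16')
    engine = Engine(model, optimizer=opt, accelerator=acc)

    # Alternatively, omit `accelerator=` and pass mixed_precision='bf16' plus
    # Accelerator kwargs such as split_batches to let Engine build one itself.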
@@ -145,6 +153,21 @@ class Engine:
         # Fire the initialization callbacks
         self._fire_event("on_init")
 
+    @property
+    def device(self):
+        '''The device currently in use.'''
+        return self.accelerator.device
+
+    @property
+    def use_amp(self) -> bool:
+        '''Whether mixed-precision training is enabled.'''
+        return self.accelerator.mixed_precision != 'no'
+
+    @property
+    def scaler(self):
+        '''The GradScaler object, if one exists.'''
+        return self.accelerator.scaler
+
     def stop(self, source: str = "User", reason: str = "Unknown"):
         '''Request that training stop.
 
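device, use_amp, and scaler are now read-through properties, so callers observe whatever the Accelerator actually resolved at runtime rather than values frozen at construction. A small sketch of the consequence (assuming a CUDA machine, since Accelerate rejects fp16 on CPU):

    import torch.nn as nn
    from orbit.engine import Engine

    engine = Engine(nn.Linear(4, 2), mixed_precision='fp16')

    # All three read through to the Accelerator instead of stored fields.
    assert engine.device == engine.accelerator.device
    assert engine.use_amp                  # 'fp16' != 'no'
    scaler = engine.scaler                 # Accelerate's GradScaler under fp16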
@@ -157,10 +180,15 @@ class Engine:
         self.stop_reason = reason
 
     def unwrap_model(self) -> nn.Module:
-        '''
-
-
-
+        '''Get the original model object.
+
+        Uses accelerator.unwrap_model to strip DDP/FSDP/DeepSpeed and other
+        distributed wrappers and return the original nn.Module instance.
+
+        Returns:
+            nn.Module: The original model object.
+        '''
+        return self.accelerator.unwrap_model(self.model)
 
     def is_in_warmup(self) -> bool:
         '''Check whether training is currently in the warmup phase.
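unwrap_model matters mostly at save/export time, when the distributed wrapper must not leak into the checkpoint. A sketch of the usual pattern; the file name is arbitrary:

    import torch
    import torch.nn as nn
    from orbit.engine import Engine

    engine = Engine(nn.Linear(4, 2))

    # Under DDP the prepared self.model is a wrapper; unwrapping first keeps
    # the state_dict keys free of the 'module.' prefix.
    torch.save(engine.unwrap_model().state_dict(), 'weights.pt')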
@@ -179,20 +207,25 @@ class Engine:
     def init_board(self, log_dir: str = 'runs') -> 'Engine':
         '''Initialize the TensorBoard visualization plugin (Board).
 
+        Note: this runs only on the main process, to avoid multi-process write conflicts.
+
         Args:
             log_dir (str, optional): Directory for TensorBoard logs. Defaults to 'runs'.
 
         Returns:
             Engine: The Engine instance itself, for method chaining.
         '''
-
-        self.
+        # Initialize the Board on the main process only
+        if self.accelerator.is_main_process:
+            board = Board(name=self.model_name, log_dir=log_dir)
+            self.attach(board, init=True)
         return self
 
     def set_checkpoint(self, dir: str, name: Optional[str] = None, **kwargs) -> 'Engine':
         '''Configure the Checkpoint plugin.
 
         If a Checkpoint plugin already exists, it is replaced by the new configuration.
+        Note: this runs only on the main process.
 
         Args:
             dir (str): Directory for saving models.
@@ -202,36 +235,38 @@ class Engine:
         Returns:
             Engine: The Engine instance itself, for method chaining.
         '''
+        # Configure the Checkpoint on the main process only
+        if not self.accelerator.is_main_process:
+            return self
+
         if name is None:
             name = self.model_name
 
-        # 1. Remove the old Checkpoint plugin (if present)
         self.plugins = [p for p in self.plugins if not isinstance(p, Checkpoint)]
 
-        # 2. Create the new plugin
         ckpt = Checkpoint(name=name, path=dir, **kwargs)
-
-        # 3. Call ckpt.on_init(event)
-        # Note: the Event is constructed manually here because we may not be inside the run loop yet
         ckpt.on_init(Event(engine=self, name="on_init"))
-
-        # 4. Attach
         self.attach(ckpt)
         return self
 
     def _print_edge(self, top=True):
+        if not self.accelerator.is_main_process: return
         char = '┬' if top else '┴'
         self.console.print(' ' + '─' * 15 + char + '─' * 35)
 
     def print(self, *args, plugin: Optional[str] = None, **kwargs):
         '''Unified log-printing method with plugin prefixes.
 
+        Note: this method only emits output on the main process; calls from other processes are ignored.
+
         Args:
             *args: Content to print.
             plugin (str, optional): Plugin name. If provided, a fixed-width, colored prefix is printed
                 to distinguish log sources.
             **kwargs: Extra arguments passed to console.print.
         '''
+        if not self.accelerator.is_main_process: return
+
         if plugin:
             # Width 15, right-aligned, bold cyan
             prefix = f"[[bold cyan]{plugin:>15}[/]] "
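The rank-zero gating added above is the standard Accelerate idiom: side effects that must happen once per job (log lines, TensorBoard writers, checkpoint files) are guarded by is_main_process. A condensed sketch of the same idiom outside the Engine:

    from accelerate import Accelerator

    acc = Accelerator()

    # Executes once per job rather than once per process/GPU.
    if acc.is_main_process:
        print(f"training with {acc.num_processes} process(es) on {acc.device}")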
@@ -273,9 +308,6 @@ class Engine:
         for cb in self.plugins:
             method = getattr(cb, event_name, None)
             if method:
-                # [Changed] Removed the try-except pass.
-                # We need to see errors raised inside callbacks; otherwise debugging is hell.
-                # If defensive programming is really wanted, use console.print_exception()
                 method(event)
 
     def _process_batch_data(self, batch_data: Any):
@@ -283,12 +315,30 @@ class Engine:
 
         Supports common formats such as Tensor, List[Tensor], and Dict[str, Tensor].
         Automatically parses the batch and sets self.data and self.target.
+
+        Note: if the DataLoader has been processed by accelerator.prepare, the data is usually already on the correct device.
+        This method performs an extra check to ensure the data resides on self.device.
 
         Args:
             batch_data (Any): One batch produced by the DataLoader.
         '''
+        def to_device(data):
+            if isinstance(data, torch.Tensor):
+                return data.to(self.device)
+            elif isinstance(data, (list, tuple)):
+                return [to_device(x) for x in data]
+            elif isinstance(data, dict):
+                return {k: to_device(v) for k, v in data.items()}
+            return data
+
+        # Move the data if it is not on the current device.
+        # Note: a DataLoader prepared by accelerator.prepare usually handles device placement already.
+
         if isinstance(batch_data, (list, tuple)):
-
+            # Simple check on the first element
+            if len(batch_data) > 0 and isinstance(batch_data[0], torch.Tensor) and batch_data[0].device != self.device:
+                batch_data = to_device(batch_data)
+
             if len(batch_data) == 2:
                 self.data, self.target = batch_data
             elif len(batch_data) == 1:
@@ -298,36 +348,44 @@ class Engine:
                 self.data = batch_data[:-1]
                 self.target = batch_data[-1]
         elif isinstance(batch_data, dict):
-
+            # Simple check on the first value
+            first_val = next(iter(batch_data.values()), None)
+            if isinstance(first_val, torch.Tensor) and first_val.device != self.device:
+                batch_data = to_device(batch_data)
+
+            self.data = batch_data
             self.target = None
         else:
-
+            if isinstance(batch_data, torch.Tensor) and batch_data.device != self.device:
+                batch_data = batch_data.to(self.device)
+            self.data = batch_data
             self.target = None
 
     def update(self, loss: torch.Tensor):
         '''Perform backpropagation and the parameter update.
 
+        Uses accelerator.backward for backpropagation, with support for automatic mixed precision and gradient accumulation.
+        Also supports the two-step update logic of the SAM (Sharpness-Aware Minimization) optimizer.
+
         Args:
             loss (torch.Tensor): The loss for the current step.
         '''
         if self.is_epoch_end: return
 
-        # Keep the original loss for logging (SAM's second forward overwrites self.loss)
         original_loss = loss
         self.loss = loss
 
-
+        if self.optimizer is None:
+            self.global_step += 1
+            return
+
+        # Gradient accumulation
         backward_loss = loss
         if self.accumulation_steps > 1:
             backward_loss = loss / self.accumulation_steps
 
-
-        if self.use_amp and self.scaler:
-            self.scaler.scale(backward_loss).backward()
-        else:
-            backward_loss.backward()
+        self.accelerator.backward(backward_loss)
 
-        # 3. Optimizer step (only when the accumulation count is reached or the epoch ends)
         if (self.batch_idx + 1) % self.accumulation_steps == 0 or self.is_last_batch:
 
             # Detect whether the optimizer is SAM (duck typing)
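For reference, how the parsing above plays out for each supported batch shape; calling the private helper directly here is purely illustrative:

    import torch
    import torch.nn as nn
    from orbit.engine import Engine

    engine = Engine(nn.Linear(4, 2))

    # (data, target) pairs are split into engine.data / engine.target.
    engine._process_batch_data((torch.randn(8, 4), torch.zeros(8, dtype=torch.long)))
    print(engine.data.shape, engine.target.shape)   # torch.Size([8, 4]) torch.Size([8])

    # Dicts (and bare tensors) become engine.data with engine.target = None.
    engine._process_batch_data({'input_ids': torch.ones(8, 16)})
    print(engine.target)                            # None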
@@ -335,73 +393,43 @@ class Engine:
 
             if is_sam:
                 # --- SAM Optimizer Logic ---
-
-
-
-                self.
-
-
-
-
-
-
-
-
-
-
-
-
-                # 3.5. Second backward: compute the gradients at w + e
-                self.scaler.scale(self.loss).backward()
-
-                # 3.6. SAM second step: restore w, then update w
-                # The gradients from the second pass must be unscaled again
-                self.scaler.unscale_(self.optimizer)
-                if self.grad_clip_norm:
-                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
-
-                self.optimizer.second_step(zero_grad=True)
-                self.scaler.update()
-
-                else:
-                    # SAM handling in plain (non-AMP) mode
-                    if self.grad_clip_norm:
-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
-
-                    self.optimizer.first_step(zero_grad=True)
-
-                    self._forward_pass()
-                    self.loss.backward()
-
-                    if self.grad_clip_norm:
-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
-
-                    self.optimizer.second_step(zero_grad=True)
+                # SAM needs two backward passes; accelerate supports multiple backwards
+
+                if self.grad_clip_norm:
+                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
+
+                # First Step: compute e_w
+                self.optimizer.first_step(zero_grad=True)
+
+                # Second Forward-Backward
+                self._forward_pass()
+                self.accelerator.backward(self.loss)
+
+                if self.grad_clip_norm:
+                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
+
+                # Second Step: update the weights
+                self.optimizer.second_step(zero_grad=True)
 
-                # Restore the original loss so logging stays consistent
                 self.loss = original_loss
 
             else:
                 # --- Standard Optimizer Logic ---
-                if self.
-
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-                else:
-                    if self.grad_clip_norm:
-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
-                    self.optimizer.step()
+                if self.grad_clip_norm:
+                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
 
+                self.optimizer.step()
             self.optimizer.zero_grad()
 
         self.global_step += 1
 
     def _forward_pass(self) -> torch.Tensor:
-        '''Perform the forward pass and compute the loss (internal method).
-
+        '''Perform the forward pass and compute the loss (internal method).
+
+        Uses the accelerator.autocast() context manager to handle mixed precision automatically.
+        '''
+        # Handle mixed precision via accelerator.autocast()
+        with self.accelerator.autocast():
             if self.forward_step:
                 self.loss = self.forward_step.forward(self, self.data, self.target)
             else:
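The SAM branch above is duck-typed: any optimizer exposing first_step/second_step gets the two-pass treatment. A condensed, standalone sketch of that ascend/descend cycle under Accelerate; the function and its arguments are illustrative, not part of orbit's API:

    # `optimizer` must expose first_step/second_step (SAM-style);
    # `loss_fn(model)` recomputes the loss for the second pass.
    def sam_update(accelerator, model, optimizer, loss_fn, grad_clip_norm=None):
        loss = loss_fn(model)
        accelerator.backward(loss)                    # first backward, at w
        if grad_clip_norm:
            accelerator.clip_grad_norm_(model.parameters(), grad_clip_norm)
        optimizer.first_step(zero_grad=True)          # perturb weights to w + e_w

        loss2 = loss_fn(model)                        # second forward, at w + e_w
        accelerator.backward(loss2)
        if grad_clip_norm:
            accelerator.clip_grad_norm_(model.parameters(), grad_clip_norm)
        optimizer.second_step(zero_grad=True)         # restore w and descend
        return loss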
@@ -446,6 +474,8 @@ class Engine:
     ):
         '''Start the training loop.
 
+        This method automatically wraps the data loaders with accelerator.prepare to support distributed training.
+
         Args:
             train_loader (Any): Training data loader (usually a torch.utils.data.DataLoader).
             val_loader (Optional[Any], optional): Validation data loader.
@@ -454,6 +484,12 @@ class Engine:
                 If None, starts from 0. Used for resuming training.
             with_eval (bool, optional): Whether to run validation after each epoch. Defaults to True.
         '''
+        # Prepare the DataLoaders
+        if val_loader:
+            train_loader, val_loader = self.accelerator.prepare(train_loader, val_loader)
+        else:
+            train_loader = self.accelerator.prepare(train_loader)
+
         self.train_loader = train_loader
         self.val_loader = val_loader
         self.num_epochs = num_epochs
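Since run() now prepares the loaders itself, plain DataLoaders can be passed directly and are sharded per process under `accelerate launch`. A sketch; the keyword names follow the docstring above, and the toy data is illustrative:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from orbit.engine import Engine

    xs, ys = torch.randn(64, 4), torch.randint(0, 2, (64,))
    train_loader = DataLoader(TensorDataset(xs, ys), batch_size=8, shuffle=True)
    val_loader = DataLoader(TensorDataset(xs, ys), batch_size=8)

    model = nn.Linear(4, 2)
    engine = Engine(model,
                    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
                    criterion=nn.CrossEntropyLoss())

    # Both loaders are wrapped by accelerator.prepare inside run().
    engine.run(train_loader, val_loader, num_epochs=2, with_eval=True)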
@@ -525,19 +561,24 @@ class Engine:
         self._fire_event("on_train_end")
 
     def _train_epoch_iterator(self, loader: Any, total_steps: Optional[int] = None, prefix: str = "Train", color: str = "blue"):
-        '''Generator: run the training loop for a single epoch.
+        '''Generator: run the training loop for a single epoch.
+
+        Note: the progress bar is shown only on the main process.
+        '''
         self.model.train()
         self.is_epoch_end = False
+        torch.cuda.empty_cache()
 
-        # Try to get the real loader length
         try:
             real_len = len(loader)
         except:
             real_len = None
 
-        # Determine the total step count for the progress bar
         num_batches = total_steps if total_steps is not None else real_len
 
+        # Show the progress bar on the main process only
+        disable_progress = not self.accelerator.is_main_process
+
         with Progress(
             TextColumn(f"[{color}]{prefix}"),
             TextColumn("[progress.description]{task.description}"),
@@ -545,21 +586,21 @@ class Engine:
             MofNCompleteColumn(),
             TimeRemainingColumn(),
             console=self.console,
-            transient=True
+            transient=True,
+            disable=disable_progress
         ) as progress:
 
             task = progress.add_task(f"[Ep {self.epoch+1}/{self.num_epochs}]", total=num_batches)
 
             for batch_idx, batch_data in enumerate(loader):
-                # Resuming: skip batches that were already trained
                 if self.epoch == self.start_epoch and batch_idx <= self.start_batch_idx:
-
+                    if not disable_progress:
+                        progress.update(task, advance=1, description=f"[dim]Skipping batch {batch_idx}...[/]")
                     continue
 
                 self.batch_idx = batch_idx
                 self.is_first_batch = (batch_idx == 0)
 
-                # Prefer the real length when determining is_last_batch
                 if real_len is not None:
                     self.is_last_batch = (batch_idx == real_len - 1)
                 elif num_batches is not None:
@@ -570,25 +611,24 @@ class Engine:
                 self._process_batch_data(batch_data)
                 self._fire_event("on_batch_start")
 
-                # Yield self to allow external control (e.g., engine.step())
                 yield self
 
-                # Update the progress bar
                 loss_val = self.loss.item() if self.loss is not None else 0.0
 
-
-
-
-
-
-
-
-
-
-
+                if not disable_progress:
+                    lr_str = ""
+                    if self.optimizer:
+                        current_lr = self.optimizer.param_groups[0]['lr']
+                        if current_lr < 1e-6:
+                            lr_str = f" LR: {current_lr:.2e}"
+                        else:
+                            lr_str = f" LR: {current_lr:.6f}"
+
+                    if self.is_in_warmup():
+                        lr_str += " [Warmup]"
 
-
-
+                    logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
+                    progress.update(task, advance=1, description=logs)
 
                 self._fire_event("on_batch_end")
 
@@ -600,7 +640,10 @@ class Engine:
         self.is_epoch_end = False
 
     def _eval_epoch_iterator(self, loader: Any, total_steps: Optional[int] = None, prefix: str = "Eval ", color: str = "yellow"):
-        '''Generator: run the validation/test loop for a single epoch.
+        '''Generator: run the validation/test loop for a single epoch.
+
+        Note: the progress bar is shown only on the main process.
+        '''
         self.model.eval()
 
         try:
@@ -609,6 +652,7 @@ class Engine:
             real_len = None
 
         num_batches = total_steps if total_steps is not None else real_len
+        disable_progress = not self.accelerator.is_main_process
 
         with Progress(
             TextColumn(f"[{color}]{prefix}"),
@@ -617,7 +661,8 @@ class Engine:
             MofNCompleteColumn(),
             TimeRemainingColumn(),
             console=self.console,
-            transient=True
+            transient=True,
+            disable=disable_progress
         ) as progress:
             lr_str = ""
             if self.optimizer:
@@ -629,6 +674,7 @@ class Engine:
 
             if self.is_in_warmup():
                 lr_str += " [Warmup]"
+
             task = progress.add_task(f"{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]", total=num_batches)
 
             with torch.no_grad():
@@ -649,8 +695,9 @@ class Engine:
                     yield self
 
                     loss_val = self.loss.item() if self.loss is not None else 0.0
-
-
+                    if not disable_progress:
+                        logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
+                        progress.update(task, advance=1, description=logs)
                     self._fire_event("on_batch_end")
 
     def train(
@@ -662,12 +709,15 @@ class Engine:
     ):
         '''Generator: start the training loop, allowing custom per-step logic.
 
+        This method automatically wraps the data loader with accelerator.prepare.
+
         Args:
             train_loader (Any): Training data loader.
             num_epochs (int, optional): Total number of training epochs.
             start_epoch (int, optional): Starting epoch.
             total_steps (int, optional): Manually specified total step count for the progress bar (for special loaders).
         '''
+        train_loader = self.accelerator.prepare(train_loader)
        self.train_loader = train_loader
         self.num_epochs = num_epochs
         if start_epoch is not None:
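train() yields the engine once per batch, after _process_batch_data has populated data and target, so the per-step logic stays in user code while update() handles backward, clipping, and the optimizer step. A sketch, with toy data and keyword names taken from the docstring above:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from orbit.engine import Engine

    model = nn.Linear(4, 2)
    engine = Engine(model, optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
    loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,))), batch_size=8)

    criterion = nn.CrossEntropyLoss()
    for step in engine.train(loader, num_epochs=2):
        logits = step.model(step.data)         # step is the engine itself
        loss = criterion(logits, step.target)
        step.update(loss)                      # accelerator.backward + clip + step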
@@ -682,7 +732,6 @@ class Engine:
 
         self._fire_event("on_epoch_start")
 
-        # Iterate via the generator
         epoch_loss_sum = 0.0
         count = 0
 
@@ -700,7 +749,6 @@ class Engine:
 
         self._fire_event("on_epoch_end")
 
-        # Compute and print the epoch summary
         avg_loss = epoch_loss_sum / count if count > 0 else 0.0
         self.metrics['train_loss'] = avg_loss
 
@@ -736,11 +784,14 @@ class Engine:
     ):
         '''Generator: start the evaluation loop.
 
+        This method automatically wraps the data loader with accelerator.prepare.
+
         Args:
             val_loader (Any): Validation data loader.
             total_steps (int, optional): Manually specified total step count for the progress bar.
             description (str, optional): Progress bar description.
         '''
+        val_loader = self.accelerator.prepare(val_loader)
         self.val_loader = val_loader
         self.state = "EVAL"
         self._fire_event("on_eval_start")
@@ -800,3 +851,16 @@ class Engine:
         plugin: Checkpoint = [p for p in self.plugins if not isinstance(p, Checkpoint)][0]
         plugin._load(self, path)
         return self
+
+    def load(self, path: str, strict: bool = True) -> 'Engine':
+        '''Load model weights from a file.
+
+        Args:
+            path (str): Path to the weights file.
+            strict (bool): Whether to strictly match the state-dict keys. Defaults to True.
+
+        Returns:
+            Engine: The Engine instance itself, for method chaining.
+        '''
+        load_model(self.unwrap_model(), path, strict=strict, map_location=self.device)
+        return self
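The new load() complements the checkpoint path: it pulls raw weights into the unwrapped model via orbit.utils.load_model, mapped onto engine.device. Usage sketch, assuming a weights file saved earlier:

    import torch.nn as nn
    from orbit.engine import Engine

    engine = Engine(nn.Linear(4, 2))
    engine.load('weights.pt')                 # strict key matching
    engine.load('weights.pt', strict=False)   # tolerate missing/extra keys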
orbit/kit/__init__.py
ADDED