orbit-torch 0.0.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -0
- orbit/callback.py +54 -0
- orbit/engine.py +802 -0
- orbit/optim/__init__.py +2 -0
- orbit/optim/muon.py +193 -0
- orbit/optim/sam.py +92 -0
- orbit/plugin/__init__.py +10 -0
- orbit/plugin/board.py +61 -0
- orbit/plugin/checkpoint.py +245 -0
- orbit/plugin/classification.py +190 -0
- orbit/plugin/data/mentor_i18n.json +102 -0
- orbit/plugin/display_model.py +75 -0
- orbit/plugin/early_stopping.py +101 -0
- orbit/plugin/ema.py +97 -0
- orbit/plugin/gradient_accumulation.py +32 -0
- orbit/plugin/memory_estimator.py +234 -0
- orbit/plugin/mentor.py +313 -0
- orbit/plugin/overfit.py +30 -0
- orbit/plugin/warmup.py +119 -0
- orbit/utils/__init__.py +29 -0
- orbit/utils/freeze.py +59 -0
- orbit/utils/initialization.py +501 -0
- orbit/utils/layer_io.py +55 -0
- orbit/utils/mask.py +92 -0
- orbit/utils/seed.py +66 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +25 -0
- orbit_torch-0.0.4a1.dist-info/RECORD +29 -0
- orbit_torch-0.0.4a1.dist-info/WHEEL +5 -0
- orbit_torch-0.0.4a1.dist-info/top_level.txt +1 -0
orbit/engine.py
ADDED
@@ -0,0 +1,802 @@
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import torch
import torch.nn as nn
from typing import Any, List, Optional, Union, Dict, Tuple

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None  # TensorBoard is optional

from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console

from orbit.callback import Callback, Forward, Event
from orbit.plugin.checkpoint import Checkpoint
from orbit.plugin.board import Board
from orbit.plugin.display_model import ModelSummary

class Engine:
    '''Training-loop controller that coordinates model training, validation, and callback events.

    Engine wraps the PyTorch training loop and provides a plugin mechanism (Callback),
    with support for automatic mixed precision (AMP) training, gradient clipping,
    gradient accumulation, checkpoint saving, TensorBoard visualization, and more.
    '''
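
    # Usage sketch (illustrative, not part of the shipped file; `net`, `opt`,
    # and the loader names are assumptions):
    #
    #     engine = Engine(net, optimizer=opt, criterion=nn.CrossEntropyLoss(),
    #                     use_amp=True, checkpoint_dir='ckpt')
    #     engine.run(train_loader, val_loader, num_epochs=10)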

    class _OutLogs:
        def __init__(self, engine: 'Engine'):
            self.engine = engine
        def __enter__(self):
            # self.engine._print_edge(top=False)
            self.engine.console.print('\n')
        def __exit__(self, exc_type, exc_val, exc_tb):
            self.engine.console.print('\n')
            # self.engine._print_edge(top=True)

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer = None,
        criterion: nn.Module = None,
        device: Optional[str] = None,
        device_ids: Optional[List[int]] = None,
        use_amp: bool = False,
        grad_clip_norm: float = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        plugins: List[Callback] = None,
        forward_step: Optional[Forward] = None,
        checkpoint_dir: str = None,
        console: Console = None,
    ):
        '''Initialize an Engine instance.

        Args:
            model (nn.Module): The PyTorch model to train.
            optimizer (torch.optim.Optimizer, optional): The optimizer. If None, it must be handled elsewhere (e.g., by a plugin) or assigned later.
            criterion (nn.Module, optional): The loss function. If None, the model output is assumed to contain the loss, or a custom loss computation is used.
            device (Optional[str], optional): Device to run on ('cpu', 'cuda', 'cuda:0', ...). If None, it is detected automatically.
            device_ids (Optional[List[int]], optional): List of GPU device IDs. If given and longer than 1, DataParallel is enabled.
            use_amp (bool, optional): Whether to enable automatic mixed precision (AMP) training. Defaults to False.
            grad_clip_norm (float, optional): Norm threshold for gradient clipping. If None, no clipping is performed.
            scheduler (Optional[torch.optim.lr_scheduler._LRScheduler], optional): Learning-rate scheduler.
            plugins (List[Callback], optional): Callback plugins to attach at initialization.
            forward_step (Optional[Forward], optional): Custom implementation of the forward pass and loss computation.
            checkpoint_dir (str, optional): Shortcut for setting the checkpoint save directory.
            console (Console, optional): Rich Console instance for log output. If None, a new one is created.
        '''
        # --- Core components ---
        self.device_ids = device_ids

        if self.device_ids and len(self.device_ids) > 0:
            self.device = torch.device(f"cuda:{self.device_ids[0]}")
        elif device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Move the model to the primary device
        model = model.to(self.device)

        # Multi-GPU handling (DataParallel)
        if self.device_ids and len(self.device_ids) > 1:
            self.model = nn.DataParallel(model, device_ids=self.device_ids)
        else:
            self.model = model

        self.model_name = self.unwrap_model().__class__.__name__
        self.optimizer = optimizer
        self.criterion = criterion

        # --- Training configuration ---
        self.use_amp = use_amp
        self.grad_clip_norm = grad_clip_norm
        self.scheduler = scheduler
        self.forward_step = forward_step
        self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

        # --- Interaction and callbacks ---
        self.console = console if console else Console()
        self.out_logs = self._OutLogs(self)
        self.writer: Optional[SummaryWriter] = None
        self.plugins = [
            ModelSummary(model),
        ]
        self.attach(plugins)

        if checkpoint_dir:
            self.attach(Checkpoint(name=self.model_name, path=checkpoint_dir))

        # --- Internal state ---
        self.num_epochs = 0
        self.start_epoch = 0

        self.global_step = 0       # Global step counter
        self.epoch = 0             # Current epoch
        self.batch_idx = 0         # Current batch index
        self.start_batch_idx = -1  # Starting batch index when resuming (this index and earlier are skipped)
        self.is_first_batch = False
        self.is_last_batch = False
        self.is_end_of_epoch = False
        self.is_epoch_end = False

        self.state = "IDLE"         # TRAIN / EVAL
        self.stop_training = False  # Plugins can stop training by setting this flag to True
        self.stop_source: Optional[str] = None
        self.stop_reason: Optional[str] = None
        self.accumulation_steps = 1  # Gradient-accumulation steps

        self.exception: Optional[Exception] = None

        # Data containers for the current batch
        self.data: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]] = None
        self.target: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]] = None
        self.output: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]] = None
        self.loss: torch.Tensor = None
        self.metrics: Dict[str, Any] = {}  # Per-epoch statistics

        # --- Persistent metadata (Meta) ---
        # This dictionary is saved and loaded together with checkpoints.
        # Plugins can use it to store any state that must survive a training
        # interruption/resume, e.g. EarlyStopping's best_score or Warmup's state.
        # Usage: engine.meta['plugin_name'] = { ... state ... }
        self.meta: Dict[str, Any] = {}

        # Fire the initialization callback
        self._fire_event("on_init")
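
    # Plugin-state sketch (illustrative): a callback can persist state across
    # checkpoint save/load through engine.meta; `MyPlugin` is a hypothetical name.
    #
    #     class MyPlugin(Callback):
    #         def on_epoch_end(self, event):
    #             event.engine.meta['my_plugin'] = {'last_epoch': event.engine.epoch}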

    def stop(self, source: str = "User", reason: str = "Unknown"):
        '''Request that training stop.

        Args:
            source (str): Origin of the stop request (e.g. "EarlyStopping", "User", "KeyboardInterrupt").
            reason (str): The specific reason for stopping.
        '''
        self.stop_training = True
        self.stop_source = source
        self.stop_reason = reason

    def unwrap_model(self) -> nn.Module:
        '''Return the raw model object (stripped of DataParallel/DistributedDataParallel wrappers).'''
        if isinstance(self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return self.model.module
        return self.model

    def is_in_warmup(self) -> bool:
        '''Check whether training is currently in the warmup phase.

        Looks for a Warmup plugin among the attached plugins and checks whether
        the current global step is still within its warmup range.

        Returns:
            bool: True if in the warmup phase, otherwise False.
        '''
        for p in self.plugins:
            if p.__class__.__name__ == 'Warmup' and hasattr(p, 'total_warmup_steps'):
                if self.global_step <= p.total_warmup_steps:
                    return True
        return False

    def init_board(self, log_dir: str = 'runs') -> 'Engine':
        '''Initialize the TensorBoard visualization plugin (Board).

        Args:
            log_dir (str, optional): Directory for TensorBoard logs. Defaults to 'runs'.

        Returns:
            Engine: The Engine instance itself, to support chaining.
        '''
        board = Board(name=self.model_name, log_dir=log_dir)
        self.attach(board, init=True)
        return self

    def set_checkpoint(self, dir: str, name: Optional[str] = None, **kwargs) -> 'Engine':
        '''Configure the Checkpoint plugin.

        If a Checkpoint plugin already exists, it is replaced by the new configuration.

        Args:
            dir (str): Directory in which to save models.
            name (str, optional): Model name prefix. If None, model_name is used.
            **kwargs: Extra arguments passed to the Checkpoint constructor (monitor, save_top_k, mode, etc.).

        Returns:
            Engine: The Engine instance itself, to support chaining.
        '''
        if name is None:
            name = self.model_name

        # 1. Remove the old Checkpoint plugin, if any
        self.plugins = [p for p in self.plugins if not isinstance(p, Checkpoint)]

        # 2. Create the new plugin
        ckpt = Checkpoint(name=name, path=dir, **kwargs)

        # 3. Call ckpt.on_init(event)
        # Note: the Event is constructed manually here, since we may not be inside a run loop
        ckpt.on_init(Event(engine=self, name="on_init"))

        # 4. Attach it
        self.attach(ckpt)
        return self
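
    # Chaining sketch (illustrative; 'runs' and 'ckpt' are example paths, and
    # `monitor` is assumed to be a Checkpoint keyword, per the docstring above):
    #
    #     engine.init_board('runs').set_checkpoint('ckpt', monitor='val_loss')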

    def _print_edge(self, top=True):
        char = '┬' if top else '┴'
        self.console.print(' ' + '─' * 15 + char + '─' * 35)

    def print(self, *args, plugin: Optional[str] = None, **kwargs):
        '''Unified log-printing method with optional plugin prefix.

        Args:
            *args: Content to print.
            plugin (str, optional): Plugin name. If given, a fixed-width, colored
                prefix is printed to distinguish log sources.
            **kwargs: Extra arguments passed through to console.print.
        '''
        if plugin:
            # Width 15, right-aligned, bold cyan
            prefix = f"[[bold cyan]{plugin:>15}[/]] "
            self.console.print(prefix, *args, **kwargs)
        else:
            self.console.print(*args, **kwargs)

    def attach(self, plugin: Union[Callback, List[Callback]] = None, init: bool = False):
        '''Attach one or more plugins to the Engine.

        Args:
            plugin (Union[Callback, List[Callback]], optional): Plugin or list of plugins to attach.
            init (bool, optional): Whether to call the plugin's on_init immediately.
                Usually set to True when adding plugins dynamically after Engine
                initialization. Defaults to False.

        Raises:
            ValueError: If the given object is not a Callback instance.
        '''
        if not plugin: return
        if isinstance(plugin, Callback):
            plugin = [plugin]
        for p in plugin:
            if not isinstance(p, Callback):
                raise ValueError(f"Plugin {p} is not a Callback!")
            if p in self.plugins: continue
            if init: p.on_init(Event(engine=self, name="on_init"))
            self.plugins.append(p)

    def _fire_event(self, event_name: str, **kwargs):
        '''Fire the given event on all attached plugins (internal).

        Plugins are invoked in attachment order.

        Args:
            event_name (str): Name of the event to fire (e.g. 'on_epoch_start').
            **kwargs: Extra arguments passed to the Event constructor (e.g. source, reason).
        '''
        event = Event(engine=self, name=event_name, **kwargs)
        for cb in self.plugins:
            method = getattr(cb, event_name, None)
            if method:
                # [Changed] The try-except-pass was removed:
                # errors inside callbacks must be visible, otherwise debugging is hell.
                # If defensive handling is truly needed, use console.print_exception().
                method(event)

    def _process_batch_data(self, batch_data: Any):
        '''Process a batch and move it to the target device (internal).

        Supports common formats such as Tensor, List[Tensor], and Dict[str, Tensor].
        Parses the batch and sets self.data and self.target.

        Args:
            batch_data (Any): One batch produced by the DataLoader.
        '''
        if isinstance(batch_data, (list, tuple)):
            batch_data = [x.to(self.device) if isinstance(x, torch.Tensor) else x for x in batch_data]
            if len(batch_data) == 2:
                self.data, self.target = batch_data
            elif len(batch_data) == 1:
                self.data = batch_data[0]
                self.target = None
            else:
                self.data = batch_data[:-1]
                self.target = batch_data[-1]
        elif isinstance(batch_data, dict):
            self.data = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch_data.items()}
            self.target = None
        else:
            self.data = batch_data.to(self.device)
            self.target = None
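
    # Batch-format sketch (illustrative): how common loader outputs map onto
    # (self.data, self.target) after _process_batch_data.
    #
    #     (x, y)       -> data = x,          target = y
    #     (x,)         -> data = x,          target = None
    #     (x1, x2, y)  -> data = [x1, x2],   target = y
    #     {'ids': t}   -> data = {'ids': t}, target = None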

    def update(self, loss: torch.Tensor):
        '''Run backpropagation and the parameter update.

        Args:
            loss (torch.Tensor): Loss for the current step.
        '''
        if self.is_epoch_end: return

        # Keep the original loss for logging (SAM's second forward pass overwrites self.loss)
        original_loss = loss
        self.loss = loss

        # 1. Gradient accumulation: scale the loss (for backward only)
        backward_loss = loss
        if self.accumulation_steps > 1:
            backward_loss = loss / self.accumulation_steps

        # 2. Backward 1 (compute gradients)
        if self.use_amp and self.scaler:
            self.scaler.scale(backward_loss).backward()
        else:
            backward_loss.backward()

        # 3. Optimizer step (only when the accumulation count is reached or at the end of the epoch)
        if (self.batch_idx + 1) % self.accumulation_steps == 0 or self.is_last_batch:

            # Detect a SAM optimizer via duck typing
            is_sam = hasattr(self.optimizer, 'first_step') and hasattr(self.optimizer, 'second_step')

            if is_sam:
                # --- SAM Optimizer Logic ---
                if self.use_amp and self.scaler:
                    # SAM under AMP
                    # 3.1. Unscale gradients so first_step computes the correct epsilon
                    self.scaler.unscale_(self.optimizer)

                    # 3.2. Gradient clipping (optional)
                    if self.grad_clip_norm:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)

                    # 3.3. SAM first step: w -> w + e
                    # Note: gradients are assumed valid after unscaling.
                    self.optimizer.first_step(zero_grad=True)

                    # 3.4. Second forward: compute the loss at w + e
                    # _forward_pass updates self.loss, so it is restored afterwards
                    self._forward_pass()

                    # 3.5. Second backward: compute the gradients at w + e
                    self.scaler.scale(self.loss).backward()

                    # 3.6. SAM second step: restore w, then update w
                    # The second set of gradients must be unscaled again
                    self.scaler.unscale_(self.optimizer)
                    if self.grad_clip_norm:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)

                    self.optimizer.second_step(zero_grad=True)
                    self.scaler.update()

                else:
                    # SAM without AMP
                    if self.grad_clip_norm:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)

                    self.optimizer.first_step(zero_grad=True)

                    self._forward_pass()
                    self.loss.backward()

                    if self.grad_clip_norm:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)

                    self.optimizer.second_step(zero_grad=True)

                # Restore the original loss so logging stays consistent
                self.loss = original_loss

            else:
                # --- Standard Optimizer Logic ---
                if self.use_amp and self.scaler:
                    if self.grad_clip_norm:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    if self.grad_clip_norm:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
                    self.optimizer.step()

            self.optimizer.zero_grad()

        self.global_step += 1
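
    # Accumulation arithmetic (illustrative): with accumulation_steps = 4 and a
    # per-batch size of 32, update() divides the backward loss by 4 and steps the
    # optimizer on every 4th batch, so one step integrates an effective batch of
    # 4 * 32 = 128 while the accumulated gradient matches the large-batch mean.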

    def _forward_pass(self) -> torch.Tensor:
        '''Run the forward pass and compute the loss (internal).'''
        with torch.amp.autocast(device_type=self.device.type, enabled=self.use_amp):
            if self.forward_step:
                self.loss = self.forward_step.forward(self, self.data, self.target)
            else:
                if isinstance(self.data, (list, tuple)):
                    self.output = self.model(*self.data)
                else:
                    self.output = self.model(self.data)

                if self.output is None:
                    raise ValueError("Model returned None! Please check your model's forward() method.")

                if self.criterion and self.target is not None:
                    self.loss = self.criterion(self.output, self.target)
                else:
                    self.loss = torch.tensor(0.0, device=self.device)

        return self.loss

    def auto_update(self) -> torch.Tensor:
        '''Automatically run the forward pass, loss computation, backpropagation, and parameter update.

        When called in evaluation mode (EVAL), only the forward pass and loss computation are run.

        Returns:
            torch.Tensor: Loss for the current step (unscaled).
        '''
        loss = self._forward_pass()

        # Only update in training mode
        if self.state == "TRAIN":
            self.update(loss)

        return loss
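
    # Custom forward sketch (illustrative): a Forward implementation receives the
    # engine plus the parsed batch and returns the loss; `MyForward` is a
    # hypothetical name.
    #
    #     class MyForward(Forward):
    #         def forward(self, engine, data, target):
    #             engine.output = engine.model(data)
    #             return engine.criterion(engine.output, target)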

    def run(
        self,
        train_loader: Any,
        val_loader: Optional[Any] = None,
        num_epochs: int = 10,
        start_epoch: Optional[int] = None,
        with_eval: bool = True
    ):
        '''Start the training loop.

        Args:
            train_loader (Any): Training data loader (usually a torch.utils.data.DataLoader).
            val_loader (Optional[Any], optional): Validation data loader.
            num_epochs (int, optional): Total number of training epochs. Defaults to 10.
            start_epoch (Optional[int], optional): Starting epoch index.
                If None, training starts from 0. Used for resuming from a checkpoint.
            with_eval (bool, optional): Whether to run validation after each epoch. Defaults to True.
        '''
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        if start_epoch is not None:
            self.start_epoch = start_epoch

        self._fire_event("on_train_start")
        try:
            for epoch in range(self.start_epoch, self.num_epochs):
                self.epoch = epoch
                self.metrics = {}

                # --- 1. Training Loop ---
                self.state = "TRAIN"
                self._fire_event("on_epoch_start")
                self._run_one_epoch(self.train_loader, prefix="Train", color="blue")

                if self.stop_training:
                    if self.epoch < self.num_epochs - 1:
                        source = self.stop_source if self.stop_source else "Plugin"
                        reason = self.stop_reason if self.stop_reason else "Unknown"
                        self.print(f"[yellow]Training stopped by {source}: {reason}[/]", plugin='Engine')
                        self._fire_event("on_requested_stop", source=source, reason=reason)
                    break

                # --- 2. Validation Loop ---
                if self.val_loader and with_eval:
                    self.state = "EVAL"
                    self._fire_event("on_eval_start")
                    with torch.no_grad():
                        self._run_one_epoch(self.val_loader, prefix="Eval ", color="yellow")
                    self._fire_event("on_eval_end")

                if self.scheduler:
                    self.scheduler.step()

                self._fire_event("on_epoch_end")

                # Print the epoch summary
                lr_str = ""
                if self.optimizer:
                    current_lr = self.optimizer.param_groups[0]['lr']
                    if current_lr < 1e-6:
                        lr_str = f" | LR: {current_lr:.2e}"
                    else:
                        lr_str = f" | LR: {current_lr:.6f}"

                if self.is_in_warmup():
                    lr_str += " [Warmup]"

                msg = f"[dark_magenta]Epoch {self.epoch+1}/{self.num_epochs}"
                if "train_loss" in self.metrics:
                    msg += f" | Train Loss: {self.metrics['train_loss']:.4f}"
                if "val_loss" in self.metrics:
                    msg += f" | Val Loss: {self.metrics['val_loss']:.4f}"
                msg += lr_str

                self.print(msg, plugin='Engine')

        except KeyboardInterrupt:
            self.print("[red][bold]Training interrupted by user.", plugin='Engine')
            self.stop(source="User", reason="KeyboardInterrupt")
            self._fire_event("on_requested_stop", source="User", reason="KeyboardInterrupt")
        except Exception as e:
            self.exception = e
            self.console.print_exception()
            self._fire_event("on_exception")
        finally:
            self._fire_event("on_train_end")
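
    # Resume sketch (illustrative, assuming a checkpoint previously written by
    # the Checkpoint plugin at the example path 'ckpt/last.pt'):
    #
    #     engine.load_checkpoint('ckpt/last.pt')
    #     engine.run(train_loader, val_loader, num_epochs=10)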

    def _train_epoch_iterator(self, loader: Any, total_steps: Optional[int] = None, prefix: str = "Train", color: str = "blue"):
        '''Generator: run the training loop for a single epoch.'''
        self.model.train()
        self.is_epoch_end = False

        # Try to obtain the real loader length
        try:
            real_len = len(loader)
        except TypeError:
            real_len = None

        # Determine the total step count for the progress bar
        num_batches = total_steps if total_steps is not None else real_len

        with Progress(
            TextColumn(f"[{color}]{prefix}"),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            console=self.console,
            transient=True
        ) as progress:

            task = progress.add_task(f"[Ep {self.epoch+1}/{self.num_epochs}]", total=num_batches)

            for batch_idx, batch_data in enumerate(loader):
                # Resuming: skip batches that have already been trained
                if self.epoch == self.start_epoch and batch_idx <= self.start_batch_idx:
                    progress.update(task, advance=1, description=f"[dim]Skipping batch {batch_idx}...[/]")
                    continue

                self.batch_idx = batch_idx
                self.is_first_batch = (batch_idx == 0)

                # Prefer the real length when deciding is_last_batch
                if real_len is not None:
                    self.is_last_batch = (batch_idx == real_len - 1)
                elif num_batches is not None:
                    self.is_last_batch = (batch_idx == num_batches - 1)
                else:
                    self.is_last_batch = False

                self._process_batch_data(batch_data)
                self._fire_event("on_batch_start")

                # Yield self to allow external control (e.g., engine.step())
                yield self

                # Update the progress bar
                loss_val = self.loss.item() if self.loss is not None else 0.0

                lr_str = ""
                if self.optimizer:
                    current_lr = self.optimizer.param_groups[0]['lr']
                    if current_lr < 1e-6:
                        lr_str = f" LR: {current_lr:.2e}"
                    else:
                        lr_str = f" LR: {current_lr:.6f}"

                if self.is_in_warmup():
                    lr_str += " [Warmup]"

                logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
                progress.update(task, advance=1, description=logs)

                self._fire_event("on_batch_end")

                if self.stop_training: break

        if not self.stop_training:
            self.is_epoch_end = True
            yield self
            self.is_epoch_end = False

    def _eval_epoch_iterator(self, loader: Any, total_steps: Optional[int] = None, prefix: str = "Eval ", color: str = "yellow"):
        '''Generator: run the validation/test loop for a single epoch.'''
        self.model.eval()

        try:
            real_len = len(loader)
        except TypeError:
            real_len = None

        num_batches = total_steps if total_steps is not None else real_len

        with Progress(
            TextColumn(f"[{color}]{prefix}"),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            console=self.console,
            transient=True
        ) as progress:
            lr_str = ""
            if self.optimizer:
                current_lr = self.optimizer.param_groups[0]['lr']
                if current_lr < 1e-6:
                    lr_str = f" LR: {current_lr:.2e}"
                else:
                    lr_str = f" LR: {current_lr:.6f}"

            if self.is_in_warmup():
                lr_str += " [Warmup]"
            task = progress.add_task(f"{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]", total=num_batches)

            with torch.no_grad():
                for batch_idx, batch_data in enumerate(loader):
                    self.batch_idx = batch_idx
                    self.is_first_batch = (batch_idx == 0)

                    if real_len is not None:
                        self.is_last_batch = (batch_idx == real_len - 1)
                    elif num_batches is not None:
                        self.is_last_batch = (batch_idx == num_batches - 1)
                    else:
                        self.is_last_batch = False

                    self._process_batch_data(batch_data)
                    self._fire_event("on_batch_start")

                    yield self

                    loss_val = self.loss.item() if self.loss is not None else 0.0
                    logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
                    progress.update(task, advance=1, description=logs)
                    self._fire_event("on_batch_end")

    def train(
        self,
        train_loader: Any,
        num_epochs: int = 10,
        start_epoch: Optional[int] = None,
        total_steps: Optional[int] = None
    ):
        '''Generator: start the training loop while letting the caller define the step logic.

        Args:
            train_loader (Any): Training data loader.
            num_epochs (int, optional): Total number of training epochs.
            start_epoch (int, optional): Starting epoch.
            total_steps (int, optional): Manually specified progress-bar total (for special loaders).
        '''
        self.train_loader = train_loader
        self.num_epochs = num_epochs
        if start_epoch is not None:
            self.start_epoch = start_epoch

        self._fire_event("on_train_start")
        try:
            for epoch in range(self.start_epoch, self.num_epochs):
                self.epoch = epoch
                self.metrics = {}
                self.state = "TRAIN"

                self._fire_event("on_epoch_start")

                # Iterate via the generator
                epoch_loss_sum = 0.0
                count = 0

                for _ in self._train_epoch_iterator(self.train_loader, total_steps=total_steps):
                    yield self
                    if self.loss is not None and not self.is_epoch_end:
                        epoch_loss_sum += self.loss.item()
                        count += 1

                if self.stop_training:
                    break

                if self.scheduler:
                    self.scheduler.step()

                self._fire_event("on_epoch_end")

                # Compute and print the epoch summary
                avg_loss = epoch_loss_sum / count if count > 0 else 0.0
                self.metrics['train_loss'] = avg_loss

                lr_str = ""
                if self.optimizer:
                    current_lr = self.optimizer.param_groups[0]['lr']
                    if current_lr < 1e-6:
                        lr_str = f" | LR: {current_lr:.2e}"
                    else:
                        lr_str = f" | LR: {current_lr:.6f}"
                if self.is_in_warmup():
                    lr_str += " [Warmup]"

                msg = f"[dark_magenta]Epoch {self.epoch+1}/{self.num_epochs} | Train Loss: {avg_loss:.4f}{lr_str}"
                self.print(msg, plugin='Engine')

        except KeyboardInterrupt:
            self.print("[red][bold]Training interrupted by user.", plugin='Engine')
            self.stop(source="User", reason="KeyboardInterrupt")
            self._fire_event("on_requested_stop", source="User", reason="KeyboardInterrupt")
        except Exception as e:
            self.exception = e
            self.console.print_exception()
            self._fire_event("on_exception")
        finally:
            self._fire_event("on_train_end")
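
    # Custom-step sketch (illustrative): driving training manually through the
    # train() generator; auto_update() performs forward, backward, and step.
    #
    #     for _ in engine.train(train_loader, num_epochs=5):
    #         if engine.is_epoch_end:
    #             continue
    #         engine.auto_update()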

    def eval(
        self,
        val_loader: Any,
        total_steps: Optional[int] = None,
        description: str = "Eval "
    ):
        '''Generator: start the evaluation loop.

        Args:
            val_loader (Any): Validation data loader.
            total_steps (int, optional): Manually specified progress-bar total.
            description (str, optional): Progress-bar description.
        '''
        self.val_loader = val_loader
        self.state = "EVAL"
        self._fire_event("on_eval_start")

        try:
            epoch_loss_sum = 0.0
            count = 0

            for _ in self._eval_epoch_iterator(self.val_loader, total_steps=total_steps, prefix=description):
                yield self
                if self.loss is not None:
                    epoch_loss_sum += self.loss.item()
                    count += 1

            avg_loss = epoch_loss_sum / count if count > 0 else 0.0
            self.metrics['val_loss'] = avg_loss

        except KeyboardInterrupt:
            self.print("[red][bold]Eval interrupted by user.", plugin='Engine')
            self.stop(source="User", reason="KeyboardInterrupt")
            self._fire_event("on_requested_stop", source="User", reason="KeyboardInterrupt")
        except Exception as e:
            self.exception = e
            self.console.print_exception()
            self._fire_event("on_exception")
        finally:
            self._fire_event("on_eval_end")
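
    # Eval-loop sketch (illustrative): in the EVAL state auto_update() only runs
    # the forward pass, so this accumulates val_loss without touching weights.
    #
    #     for _ in engine.eval(val_loader):
    #         engine.auto_update()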

    def _run_one_epoch(self, loader: Any, prefix: str = "Train", color: str = "blue"):
        '''Run the loop for a single epoch (internal).'''
        is_train = (self.state == "TRAIN")

        if is_train:
            epoch_loss_sum = 0.0
            count = 0
            for _ in self._train_epoch_iterator(loader, prefix=prefix, color=color):
                if self.is_epoch_end: continue
                self.auto_update()
                epoch_loss_sum += self.loss.item()
                count += 1

            avg_loss = epoch_loss_sum / count if count > 0 else 0.0
            self.metrics['train_loss'] = avg_loss

        else:
            epoch_loss_sum = 0.0
            count = 0
            for _ in self._eval_epoch_iterator(loader, prefix=prefix, color=color):
                self._forward_pass()
                epoch_loss_sum += self.loss.item()
                count += 1

            avg_loss = epoch_loss_sum / count if count > 0 else 0.0
            self.metrics['val_loss'] = avg_loss

    def load_checkpoint(self, path: str) -> 'Engine':
        plugin: Checkpoint = [p for p in self.plugins if isinstance(p, Checkpoint)][0]
        plugin._load(self, path)
        return self
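
A minimal end-to-end sketch of the API above (the toy model, the random dataset, and all hyperparameters are assumptions for illustration, not part of the package):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from orbit.engine import Engine

# Hypothetical toy setup; only Engine and its methods come from this file.
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
ds = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
train_loader = DataLoader(ds, batch_size=32, shuffle=True)

engine = Engine(
    net,
    optimizer=torch.optim.AdamW(net.parameters(), lr=1e-3),
    criterion=nn.CrossEntropyLoss(),
    checkpoint_dir='ckpt',   # attaches a Checkpoint plugin automatically
)
engine.init_board('runs')    # optional TensorBoard logging via the Board plugin
engine.run(train_loader, num_epochs=3)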