orbit-torch 0.0.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -0
- orbit/callback.py +54 -0
- orbit/engine.py +802 -0
- orbit/optim/__init__.py +2 -0
- orbit/optim/muon.py +193 -0
- orbit/optim/sam.py +92 -0
- orbit/plugin/__init__.py +10 -0
- orbit/plugin/board.py +61 -0
- orbit/plugin/checkpoint.py +245 -0
- orbit/plugin/classification.py +190 -0
- orbit/plugin/data/mentor_i18n.json +102 -0
- orbit/plugin/display_model.py +75 -0
- orbit/plugin/early_stopping.py +101 -0
- orbit/plugin/ema.py +97 -0
- orbit/plugin/gradient_accumulation.py +32 -0
- orbit/plugin/memory_estimator.py +234 -0
- orbit/plugin/mentor.py +313 -0
- orbit/plugin/overfit.py +30 -0
- orbit/plugin/warmup.py +119 -0
- orbit/utils/__init__.py +29 -0
- orbit/utils/freeze.py +59 -0
- orbit/utils/initialization.py +501 -0
- orbit/utils/layer_io.py +55 -0
- orbit/utils/mask.py +92 -0
- orbit/utils/seed.py +66 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +25 -0
- orbit_torch-0.0.4a1.dist-info/RECORD +29 -0
- orbit_torch-0.0.4a1.dist-info/WHEEL +5 -0
- orbit_torch-0.0.4a1.dist-info/top_level.txt +1 -0
orbit/optim/__init__.py
ADDED
orbit/optim/muon.py
ADDED
@@ -0,0 +1,193 @@
import math
import torch

# This code snippet is adapted from a modified version of the following GitHub repository:
# https://github.com/KellerJordan/Muon/blob/master/muon.py
@torch.compile
def zeropower_via_newtonschulz5(G, steps):
    '''
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use
    a quintic iteration whose coefficients are selected to maximize the slope at zero. To minimize
    the number of steps, it is empirically effective to keep increasing the slope at zero even
    beyond the point where the iteration no longer converges all the way to 1 everywhere on the
    interval. The iteration therefore does not produce UV^T but rather something like US'V^T,
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out to have no adverse
    effect on model performance relative to UV^T (where USV^T = G is the SVD).
    '''
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure the spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = (
            b * A + c * A @ A
        )  # adapted from suggestions by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(0) > G.size(1):
        X = X.T
    return X


class Muon(torch.optim.Optimizer):
    '''
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum and then performs an orthogonalization
    post-processing step, in which each 2D parameter's update is replaced with the nearest
    orthogonal matrix. To orthogonalize each update efficiently, we use a Newton-Schulz
    iteration, which has the advantage that it can be run stably in bfloat16 on the GPU.

    Some caveats:
    - We believe this optimizer is unlikely to work well for training with small batch sizes.
    - We believe it may not work well for fine-tuning pretrained models, but we have not tested this.

    Arguments:
        muon_params: The parameters to be optimized by Muon.
        lr: The learning rate. The spectral norm of the update will be `lr`. (0.02 is a good default)
        momentum: The momentum used by the internal SGD. (0.95 is a good default)
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
        adamw_params: The parameters to be optimized by AdamW. Any {0, 1}-D parameters in
            `muon_params`, or parameters detected as embeddings or the lm_head, are also
            optimized by AdamW.
        adamw_lr: The learning rate for the internal AdamW.
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.
        adamw_wd: The weight decay for the internal AdamW.
    '''

    def __init__(
        self,
        lr=1e-3,
        wd=0.1,
        muon_params=None,
        momentum=0.95,
        nesterov=True,
        ns_steps=5,
        adamw_params=None,
        adamw_betas=(0.9, 0.95),
        adamw_eps=1e-8,
    ):

        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
        )

        params = list(muon_params)
        adamw_params = list(adamw_params) if adamw_params is not None else []
        params.extend(adamw_params)
        super().__init__(params, defaults)
        # Sort the parameters into those for which we will use Muon and those for which we will not
        for p in muon_params:
            # Use Muon for every parameter in muon_params that is >= 2D and does not look like
            # an embedding or head layer
            assert p.ndim == 2, p.ndim
            self.state[p]['use_muon'] = True
        for p in adamw_params:
            # Do not use Muon for parameters in adamw_params
            self.state[p]['use_muon'] = False

    def adjust_lr_for_muon(self, lr, param_shape):
        A, B = param_shape[:2]
        # We adjust the learning rate and weight decay based on the size of the parameter matrix,
        # as described in the paper
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
        adjusted_lr = lr * adjusted_ratio
        return adjusted_lr

    def step(self, closure=None):
        '''
        Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that re-evaluates the model and returns the loss.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            ############################
            #           Muon           #
            ############################

            params = [p for p in group['params'] if self.state[p]['use_muon']]
            # import pdb; pdb.set_trace()
            lr = group['lr']
            wd = group['wd']
            momentum = group['momentum']

            # Generate the weight updates
            for p in params:
                # Sanity check
                g = p.grad
                if g is None:
                    continue
                if g.ndim > 2:
                    g = g.view(g.size(0), -1)
                assert g is not None

                # Compute the update
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)
                if group['nesterov']:
                    g = g.add(buf, alpha=momentum)
                else:
                    g = buf
                u = zeropower_via_newtonschulz5(g, steps=group['ns_steps'])

                # Scale the update
                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)

                # Apply weight decay
                p.data.mul_(1 - lr * wd)

                # Apply the update
                p.data.add_(u, alpha=-adjusted_lr)

            ############################
            #       AdamW backup       #
            ############################

            params = [p for p in group['params'] if not self.state[p]['use_muon']]
            lr = group['lr']
            beta1, beta2 = group['adamw_betas']
            eps = group['adamw_eps']
            weight_decay = group['wd']

            for p in params:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]
                if 'step' not in state:
                    state['step'] = 0
                    state['moment1'] = torch.zeros_like(g)
                    state['moment2'] = torch.zeros_like(g)
                state['step'] += 1
                step = state['step']
                buf1 = state['moment1']
                buf2 = state['moment2']
                buf1.lerp_(g, 1 - beta1)
                buf2.lerp_(g.square(), 1 - beta2)

                g = buf1 / (eps + buf2.sqrt())

                bias_correction1 = 1 - beta1**step
                bias_correction2 = 1 - beta2**step
                scale = bias_correction1 / bias_correction2**0.5
                p.data.mul_(1 - lr * weight_decay)
                p.data.add_(g, alpha=-lr / scale)

        return loss
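For context, the snippet below is a minimal usage sketch of the optimizer defined above; it is not part of the package or its documentation. The toy model, batch, and hyperparameter values are made up, and it assumes PyTorch 2.x (for the `@torch.compile` decorator on `zeropower_via_newtonschulz5`). It only illustrates how 2D weight matrices are routed to Muon and everything else to the internal AdamW branch, as the constructor expects.

import torch
import torch.nn as nn
from orbit.optim.muon import Muon

# Hypothetical toy model, used purely for illustration.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# 2D weight matrices go to Muon; biases and other non-2D tensors go to the AdamW branch,
# since Muon asserts p.ndim == 2 for every entry of muon_params.
muon_params = [p for p in model.parameters() if p.ndim == 2]
adamw_params = [p for p in model.parameters() if p.ndim != 2]

opt = Muon(lr=0.02, wd=0.1, muon_params=muon_params, adamw_params=adamw_params,
           momentum=0.95, nesterov=True, ns_steps=5)

x, y = torch.randn(32, 64), torch.randint(0, 10, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()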
orbit/optim/sam.py
ADDED
@@ -0,0 +1,92 @@
import torch

class SAM(torch.optim.Optimizer):
    '''Sharpness-Aware Minimization (SAM) optimizer.

    This optimizer improves generalization by seeking flat local minima in the loss landscape.
    It performs a two-step gradient update: it first finds the perturbation that maximizes the
    loss, then takes the actual gradient step from that perturbed point.

    Attributes:
        base_optimizer (torch.optim.Optimizer): The underlying optimizer instance (e.g. SGD or Adam).
        param_groups (list): The optimizer's parameter groups.
        state (dict): Per-parameter state (e.g. the perturbation vector e_w).
    '''

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        '''Initialize the SAM optimizer.

        Args:
            params (iterable): Iterable of model parameters or dicts defining parameter groups.
            base_optimizer (class): The base optimizer class (note: the class itself, e.g. torch.optim.SGD).
            rho (float, optional): Neighborhood size controlling the perturbation radius. Defaults to 0.05.
            **kwargs: Additional hyperparameters passed to the base optimizer (e.g. lr, momentum, weight_decay).

        Raises:
            AssertionError: If rho is negative.
        '''
        assert rho >= 0.0, f'Invalid rho value: {rho}, must be non-negative.'

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        '''Compute and apply the parameter perturbation epsilon.

        Based on the current gradients, this step moves the model parameters $w$ to $w + \epsilon$.

        Args:
            zero_grad (bool, optional): Whether to clear the gradients afterwards. Defaults to False.
        '''
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)

            for p in group['params']:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                self.state[p]['e_w'] = e_w  # store the perturbation so the second step can undo it
                p.add_(e_w)  # move the parameters to w + e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        '''Restore the parameters and perform the actual gradient update.

        This step moves the parameters from $w + \epsilon$ back to $w$ and updates $w$ using the
        gradients computed at the perturbed point.

        Args:
            zero_grad (bool, optional): Whether to clear the gradients after the update. Defaults to False.
        '''
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None: continue
                p.sub_(self.state[p]['e_w'])  # restore the original parameters w

        self.base_optimizer.step()  # the actual update, using the gradients computed at w + e_w

        if zero_grad: self.zero_grad()

    def _grad_norm(self):
        '''Compute the global L2 norm over all parameter gradients.

        Returns:
            torch.Tensor: The scalar L2 norm of the gradients.
        '''
        shared_device = self.param_groups[0]['params'][0].device
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups for p in group['params']
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def step(self, closure=None):
        raise NotImplementedError('SAM requires steps to be run manually: first_step and second_step')
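Because `step()` is deliberately unimplemented, SAM has to be driven with two explicit forward/backward passes per batch, calling `first_step` and `second_step` manually. The following is a minimal sketch of that loop; the model, data, and hyperparameters are placeholders and are not taken from the package.

import torch
import torch.nn as nn
from orbit.optim.sam import SAM

# Hypothetical model and batch, used only to illustrate the two forward/backward passes.
model = nn.Linear(20, 2)
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(16, 20), torch.randint(0, 2, (16,))

# Note: base_optimizer is the class itself; extra kwargs are forwarded to it.
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

# First pass: gradients at w, then perturb the weights to w + e_w.
criterion(model(x), y).backward()
optimizer.first_step(zero_grad=True)

# Second pass: gradients at w + e_w, then restore w and step the base optimizer.
criterion(model(x), y).backward()
optimizer.second_step(zero_grad=True)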
orbit/plugin/__init__.py
ADDED
@@ -0,0 +1,10 @@
try: from orbit.plugin.classification import ClassificationReport
except: pass

from orbit.plugin.warmup import Warmup
from orbit.plugin.early_stopping import EarlyStopping
from orbit.plugin.gradient_accumulation import GradientAccumulation
from orbit.plugin.mentor import Mentor
from orbit.plugin.ema import EMA  # Not tested
from orbit.plugin.memory_estimator import MemoryEstimator
from orbit.plugin.overfit import Overfit
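The guarded import above makes `ClassificationReport` optional: if importing it fails, the rest of the plugin namespace still loads. A short, hypothetical import example:

# Plugins re-exported by orbit.plugin (per the __init__ above); the selection here is arbitrary.
from orbit.plugin import Warmup, EarlyStopping, GradientAccumulation

try:
    from orbit.plugin import ClassificationReport  # only present if the guarded import succeeded
except ImportError:
    ClassificationReport = None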
orbit/plugin/board.py
ADDED
@@ -0,0 +1,61 @@
from torch.utils.tensorboard import SummaryWriter
from typing import Optional, TYPE_CHECKING
from orbit.callback import Callback, Event

if TYPE_CHECKING:
    from ..engine import Engine

class Board(Callback):
    def __init__(self, name: str, log_dir: str):
        super().__init__()
        self.log_dir = log_dir + '/' + name
        self.writer: Optional[SummaryWriter] = None

    def on_init(self, event: Event):
        '''Initialize the SummaryWriter.'''
        # The log_dir is created automatically if it does not exist
        self.writer = SummaryWriter(log_dir=self.log_dir)
        event.engine.writer = self.writer
        event.engine.print(f'[cyan]Initialized. Log dir: {self.log_dir}[/]', plugin='Board')

    def on_batch_end(self, event: Event):
        '''
        At the end of every batch, log:
        1. Train loss (per batch)
        2. Learning rate
        '''
        engine = event.engine
        if engine.state == 'TRAIN':
            # Log the training loss
            if engine.loss is not None:
                self.writer.add_scalar('Train/Batch_Loss', engine.loss.item(), engine.global_step)

            # Log the learning rate (from the first parameter group)
            if engine.optimizer:
                current_lr = engine.optimizer.param_groups[0]['lr']
                self.writer.add_scalar('Train/LR', current_lr, engine.global_step)

    def on_epoch_end(self, event: Event):
        '''
        At the end of every epoch, log:
        1. The average epoch loss (train & val)
        2. Any other metrics present in the engine.metrics dict
        '''
        engine = event.engine
        # engine.metrics: {'train_loss': 0.5, 'val_loss': 0.4, 'acc': 0.9}
        for key, value in engine.metrics.items():
            if 'loss' in key.lower():
                tag = f'Loss/{key}'
            elif 'acc' in key.lower():
                tag = f'Accuracy/{key}'
            else:
                tag = f'Metrics/{key}'

            self.writer.add_scalar(tag, value, engine.epoch + 1)

        self.writer.flush()

    def on_train_end(self, event: Event):
        '''Close the writer when training ends.'''
        if self.writer:
            self.writer.close()
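To make the tag scheme above concrete, here is a small self-contained sketch that mirrors the routing in `on_epoch_end`; the metrics dict, log directory, and epoch number are made up for illustration, whereas the real callback receives them from the Engine.

from torch.utils.tensorboard import SummaryWriter

# Mirror of Board.on_epoch_end's tag routing, applied to a hypothetical metrics dict,
# to show which TensorBoard tags the callback would produce.
metrics = {'train_loss': 0.52, 'val_loss': 0.41, 'val_acc': 0.88, 'lr_scale': 1.0}
writer = SummaryWriter(log_dir='runs/demo')
epoch = 0
for key, value in metrics.items():
    if 'loss' in key.lower():
        tag = f'Loss/{key}'        # -> Loss/train_loss, Loss/val_loss
    elif 'acc' in key.lower():
        tag = f'Accuracy/{key}'    # -> Accuracy/val_acc
    else:
        tag = f'Metrics/{key}'     # -> Metrics/lr_scale
    writer.add_scalar(tag, value, epoch + 1)
writer.flush()
writer.close()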
orbit/plugin/checkpoint.py
ADDED
@@ -0,0 +1,245 @@
import os
import torch
from typing import TYPE_CHECKING, List, Tuple
from orbit.callback import Callback, Event

if TYPE_CHECKING:
    from orbit.engine import Engine

class Checkpoint(Callback):
    def __init__(
        self,
        name: str,
        path: str,
        save_weights_only: bool = False,
        monitor: str = 'val_loss',   # monitor val_loss by default
        mode: str = 'min',           # by default, a smaller loss is better
        save_top_k: int = 1,
        save_last: bool = True,
        every_n_train_steps: int = None,
        verbose: bool = True
    ):
        """
        Args:
            name (str): Prefix used for checkpoint file names.
            path (str): Directory to save checkpoints into.
            save_weights_only (bool): Save only the model weights (no optimizer or other state).
            monitor (str): Metric to monitor (e.g. 'val_loss', 'val_acc'). Defaults to 'val_loss'.
            mode (str): 'min' (lower is better) or 'max' (higher is better).
            save_top_k (int): Keep the best K checkpoints. Set to 0 to disable top-K saving.
            save_last (bool): Always save '{name}_last.pt'.
            every_n_train_steps (int): Save a checkpoint every N training steps.
            verbose (bool): Print a message whenever a checkpoint is saved.
        """
        super().__init__()
        self.name = name
        self.path = path
        self.save_weights_only = save_weights_only
        self.monitor = monitor
        self.mode = mode
        self.save_top_k = save_top_k
        self.save_last = save_last
        self.every_n_train_steps = every_n_train_steps
        self.verbose = verbose

        # Top-K bookkeeping: [(score, filename), ...]
        self.best_k_models: List[Tuple[float, str]] = []

        # File name of the previous step checkpoint, kept so it can be deleted
        self.last_step_checkpoint: str = None

        # Key used for internal state stored in engine.meta
        self._meta_key = 'checkpoint_callback'

    def on_init(self, event: Event):
        """
        1. Create the checkpoint directory.
        2. Try to restore the callback state (best_k_models).
        3. Try to load the last checkpoint.
        """
        engine = event.engine
        if not os.path.exists(self.path):
            os.makedirs(self.path, exist_ok=True)

        # Try to restore the best_k_models state
        if self._meta_key in engine.meta:
            self.best_k_models = engine.meta[self._meta_key].get('best_k_models', [])

        load_path = os.path.join(self.path, self.name + "_last.pt").replace("\\", "/")

        if os.path.exists(load_path):
            self._load(engine, load_path)
        else:
            engine.print(f"[yellow]Warning: Resume checkpoint '{load_path}' not found. Starting from scratch.[/]", plugin='Checkpointing')

    def on_batch_end(self, event: Event):
        """
        At the end of every batch, check whether a step-based checkpoint is due.
        """
        if self.every_n_train_steps and event.engine.state == "TRAIN":
            step = event.engine.global_step
            if step > 0 and step % self.every_n_train_steps == 0:
                # Save a step checkpoint
                filename = f"{self.name}_step_{step}.pt"

                # Pass is_step=True
                self._save(event.engine, filename, verbose=self.verbose, is_step=True)

                # Delete the previous step checkpoint
                if self.last_step_checkpoint and self.last_step_checkpoint != filename:
                    self._remove(event.engine, self.last_step_checkpoint)
                self.last_step_checkpoint = filename

                # Also refresh the last checkpoint
                if self.save_last:
                    self._save(event.engine, f"{self.name}_last.pt", verbose=False, is_step=True)

    def on_epoch_end(self, event: Event):
        """
        At the end of every epoch:
        1. Save the last checkpoint.
        2. If a monitor is configured, update the top-K checkpoints.
        """
        engine = event.engine
        # 1. Save Last
        if self.save_last:
            self._save(engine, f"{self.name}_last.pt", verbose=False)  # no need to be chatty about the last checkpoint

        # 2. Save Top K
        if self.monitor and self.save_top_k > 0:
            current_score = engine.metrics.get(self.monitor)

            if current_score is None:
                if self.verbose:
                    engine.print(f"[yellow]Metric '{self.monitor}' not found in metrics. Skipping Top-K save.[/]", plugin='Checkpointing')
                return

            self._check_and_save_top_k(engine, current_score)

    def _check_and_save_top_k(self, engine: 'Engine', current_score: float):
        """Check whether the current score belongs in the top K, and save if so."""
        filename = f"{self.name}_ep{engine.epoch+1}_{self.monitor}_{current_score:.4f}.pt"

        # Keep it simple: always append, then sort, and drop the worst entry if there are more than K
        self.best_k_models.append((current_score, filename))

        # Sort
        reverse = (self.mode == 'max')
        self.best_k_models.sort(key=lambda x: x[0], reverse=reverse)

        # Handle overflow if the list is too long
        if len(self.best_k_models) > self.save_top_k:
            worst_model = self.best_k_models.pop()  # remove the last (worst) entry
            worst_score, worst_filename = worst_model

            # If the entry we just appended is the worst, it did not make the top K; nothing to save
            if worst_filename == filename:
                return

            # Otherwise save the new checkpoint and delete the previous worst one
            self._save(engine, filename, verbose=self.verbose)
            self._remove(engine, worst_filename)
        else:
            # The list is not full yet; just save
            self._save(engine, filename, verbose=self.verbose)

        # Update the meta state
        engine.meta[self._meta_key] = {'best_k_models': self.best_k_models}

    def _save(self, engine: 'Engine', filename: str, verbose: bool = True, is_step: bool = False):
        # Make sure the meta data is up to date
        if self.monitor:
            engine.meta[self._meta_key] = {'best_k_models': self.best_k_models}

        # Unwrap the raw model (strip DataParallel wrappers) so the checkpoint stays portable
        raw_model = engine.unwrap_model()

        state = {
            'epoch': engine.epoch,
            'global_step': engine.global_step,
            'batch_idx': engine.batch_idx,
            'is_step': is_step,
            'model_state_dict': raw_model.state_dict(),
            'optimizer_state_dict': engine.optimizer.state_dict() if engine.optimizer else None,
            'scheduler_state_dict': engine.scheduler.state_dict() if engine.scheduler else None,
            'scaler_state_dict': engine.scaler.state_dict() if engine.scaler else None,
            'meta': engine.meta,
        }
        if self.save_weights_only:
            state = raw_model.state_dict()

        file_path = os.path.join(self.path, filename)
        try:
            torch.save(state, file_path)
            if verbose:
                # Show a relative path, which is more concise
                rel_path = os.path.relpath(file_path)
                engine.print(f"Saved checkpoint: {rel_path}", plugin='Checkpointing')
        except Exception as e:
            engine.print(f"[red]Failed to save checkpoint: {e}[/]", plugin='Checkpointing')

    def _remove(self, engine: 'Engine', filename: str):
        """Delete an old checkpoint file."""
        file_path = os.path.join(self.path, filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
                if self.verbose:
                    engine.print(f"[dim]Removed old checkpoint: {filename}[/]", plugin='Checkpointing')
            except OSError as e:
                engine.print(f"[red]Failed to remove checkpoint {filename}: {e}[/]", plugin='Checkpointing')

    def _load(self, engine: 'Engine', file_path: str):
        """Core checkpoint-loading logic."""
        engine.print(f"[cyan]Loading checkpoint from: {file_path}[/]", plugin='Checkpointing')
        try:
            # Load onto the engine's device
            checkpoint = torch.load(file_path, map_location=engine.device)

            # Get the raw (unwrapped) model for loading
            raw_model = engine.unwrap_model()

            # 1. Load the model weights
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                raw_model.load_state_dict(checkpoint['model_state_dict'])
            else:
                raw_model.load_state_dict(checkpoint)
                engine.print("[yellow]Loaded model weights only (legacy format).[/]", plugin='Checkpointing')
                return

            # 2. Restore the training state
            if not self.save_weights_only:
                if engine.optimizer and 'optimizer_state_dict' in checkpoint:
                    engine.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                if engine.scheduler and 'scheduler_state_dict' in checkpoint:
                    engine.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                if engine.scaler and 'scaler_state_dict' in checkpoint:
                    engine.scaler.load_state_dict(checkpoint['scaler_state_dict'])

                if 'meta' in checkpoint:
                    engine.meta.update(checkpoint['meta'])

                loaded_epoch = checkpoint.get('epoch', 0)
                loaded_batch_idx = checkpoint.get('batch_idx', -1)
                is_step = checkpoint.get('is_step', False)

                if is_step:
                    # For a step checkpoint, resume from the next batch of the same epoch
                    engine.start_epoch = loaded_epoch
                    engine.start_batch_idx = loaded_batch_idx
                    msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
                else:
                    # For an epoch checkpoint, resume from the next epoch
                    engine.start_epoch = loaded_epoch + 1
                    engine.start_batch_idx = -1
                    msg = f"Epoch {engine.start_epoch}"

                engine.global_step = checkpoint.get('global_step', 0)

                engine.print(f"[green]Successfully resumed training from {msg}, Global Step {engine.global_step}[/]", plugin='Checkpointing')

        except Exception as e:
            engine.print(f"[red]Failed to load checkpoint: {e}[/]", plugin='Checkpointing')
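A hedged configuration sketch for the callback above: constructing it requires no engine, so the snippet runs as-is, but the metric, file names, and directory are illustrative, and how the callback gets attached to a training run depends on orbit.engine.Engine, which is outside this diff.

from orbit.plugin.checkpoint import Checkpoint

# Keep the 3 best checkpoints by validation accuracy, always keep a rolling "last" checkpoint,
# and additionally snapshot every 500 training steps.
ckpt = Checkpoint(
    name='resnet18',
    path='checkpoints/resnet18',
    monitor='val_acc',
    mode='max',                # higher accuracy is better
    save_top_k=3,
    save_last=True,            # checkpoints/resnet18/resnet18_last.pt
    every_n_train_steps=500,   # checkpoints/resnet18/resnet18_step_500.pt, _1000.pt, ...
)
# Attaching the callback to a run is handled by orbit.engine.Engine (not shown in this diff);
# at each epoch end it would write files such as resnet18_ep3_val_acc_0.9123.pt and prune
# anything that falls outside the top 3.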