orbit-torch 0.0.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -0
- orbit/callback.py +54 -0
- orbit/engine.py +802 -0
- orbit/optim/__init__.py +2 -0
- orbit/optim/muon.py +193 -0
- orbit/optim/sam.py +92 -0
- orbit/plugin/__init__.py +10 -0
- orbit/plugin/board.py +61 -0
- orbit/plugin/checkpoint.py +245 -0
- orbit/plugin/classification.py +190 -0
- orbit/plugin/data/mentor_i18n.json +102 -0
- orbit/plugin/display_model.py +75 -0
- orbit/plugin/early_stopping.py +101 -0
- orbit/plugin/ema.py +97 -0
- orbit/plugin/gradient_accumulation.py +32 -0
- orbit/plugin/memory_estimator.py +234 -0
- orbit/plugin/mentor.py +313 -0
- orbit/plugin/overfit.py +30 -0
- orbit/plugin/warmup.py +119 -0
- orbit/utils/__init__.py +29 -0
- orbit/utils/freeze.py +59 -0
- orbit/utils/initialization.py +501 -0
- orbit/utils/layer_io.py +55 -0
- orbit/utils/mask.py +92 -0
- orbit/utils/seed.py +66 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +25 -0
- orbit_torch-0.0.4a1.dist-info/RECORD +29 -0
- orbit_torch-0.0.4a1.dist-info/WHEEL +5 -0
- orbit_torch-0.0.4a1.dist-info/top_level.txt +1 -0
orbit/utils/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from .initialization import (
|
|
2
|
+
trunc_normal_,
|
|
3
|
+
constant_init,
|
|
4
|
+
init_weights,
|
|
5
|
+
init_layer_norm,
|
|
6
|
+
init_embedding,
|
|
7
|
+
init_weights_transformer,
|
|
8
|
+
WeightInitializer,
|
|
9
|
+
initialize_weights,
|
|
10
|
+
AutoInitializer,
|
|
11
|
+
auto_initialize
|
|
12
|
+
)
|
|
13
|
+
from .freeze import (
|
|
14
|
+
set_trainable,
|
|
15
|
+
freeze_layers,
|
|
16
|
+
unfreeze_layers,
|
|
17
|
+
get_trainable_params
|
|
18
|
+
)
|
|
19
|
+
from .seed import (
|
|
20
|
+
seed_everything,
|
|
21
|
+
worker_init_fn,
|
|
22
|
+
create_generator
|
|
23
|
+
)
|
|
24
|
+
from .mask import (
|
|
25
|
+
make_padding_mask,
|
|
26
|
+
make_lookahead_mask,
|
|
27
|
+
make_causal_mask,
|
|
28
|
+
make_sliding_window_mask
|
|
29
|
+
)
|
orbit/utils/freeze.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Union, List, Optional, Iterable
|
|
4
|
+
|
|
5
|
+
def set_trainable(
|
|
6
|
+
model: nn.Module,
|
|
7
|
+
targets: Optional[Union[str, List[str]]] = None,
|
|
8
|
+
trainable: bool = False
|
|
9
|
+
) -> None:
|
|
10
|
+
'''设置模型参数的 requires_grad 属性,用于冻结或解冻层。
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
model (nn.Module): 目标模型。
|
|
14
|
+
targets (str or List[str], optional): 要操作的层名称或参数名称模式。
|
|
15
|
+
- 如果为 None,则操作模型的所有参数。
|
|
16
|
+
- 如果为 str,则操作名称中包含该字符串的所有参数。
|
|
17
|
+
- 如果为 List[str],则操作名称中包含列表中任意字符串的所有参数。
|
|
18
|
+
trainable (bool): 是否可训练 (True 为解冻, False 为冻结)。
|
|
19
|
+
'''
|
|
20
|
+
if targets is None:
|
|
21
|
+
for param in model.parameters():
|
|
22
|
+
param.requires_grad = trainable
|
|
23
|
+
else:
|
|
24
|
+
if isinstance(targets, str):
|
|
25
|
+
targets = [targets]
|
|
26
|
+
|
|
27
|
+
for name, param in model.named_parameters():
|
|
28
|
+
# 检查参数名是否包含 targets 中的任何一个模式
|
|
29
|
+
if any(t in name for t in targets):
|
|
30
|
+
param.requires_grad = trainable
|
|
31
|
+
|
|
32
|
+
def freeze_layers(model: nn.Module, targets: Optional[Union[str, List[str]]] = None) -> None:
|
|
33
|
+
'''冻结模型指定层或所有层 (requires_grad=False)。
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
model (nn.Module): 目标模型。
|
|
37
|
+
targets (str or List[str], optional): 要冻结的层名称模式。如果不指定,则冻结整个模型。
|
|
38
|
+
'''
|
|
39
|
+
set_trainable(model, targets, trainable=False)
|
|
40
|
+
|
|
41
|
+
def unfreeze_layers(model: nn.Module, targets: Optional[Union[str, List[str]]] = None) -> None:
|
|
42
|
+
'''解冻模型指定层或所有层 (requires_grad=True)。
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
model (nn.Module): 目标模型。
|
|
46
|
+
targets (str or List[str], optional): 要解冻的层名称模式。如果不指定,则解冻整个模型。
|
|
47
|
+
'''
|
|
48
|
+
set_trainable(model, targets, trainable=True)
|
|
49
|
+
|
|
50
|
+
def get_trainable_params(model: nn.Module) -> Iterable[torch.Tensor]:
|
|
51
|
+
'''获取模型中 requires_grad=True 的参数,供优化器使用。
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
model (nn.Module): 目标模型。
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
Iterable[torch.Tensor]: 可训练参数的迭代器。
|
|
58
|
+
'''
|
|
59
|
+
return filter(lambda p: p.requires_grad, model.parameters())
|
|
@@ -0,0 +1,501 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import warnings
|
|
3
|
+
import re
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
from rich.table import Table
|
|
11
|
+
RICH_AVAILABLE = True
|
|
12
|
+
except ImportError:
|
|
13
|
+
RICH_AVAILABLE = False
|
|
14
|
+
|
|
15
|
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
|
16
|
+
'''截断正态分布初始化的辅助函数,在无梯度模式下运行。
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
tensor (torch.Tensor): 要初始化的张量。
|
|
20
|
+
mean (float): 正态分布的均值。
|
|
21
|
+
std (float): 正态分布的标准差。
|
|
22
|
+
a (float): 截断下界。
|
|
23
|
+
b (float): 截断上界。
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
torch.Tensor: 初始化后的张量。
|
|
27
|
+
'''
|
|
28
|
+
def norm_cdf(x):
|
|
29
|
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
|
30
|
+
|
|
31
|
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
|
32
|
+
warnings.warn(
|
|
33
|
+
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
|
34
|
+
'The distribution of values may be incorrect.',
|
|
35
|
+
stacklevel=2
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
with torch.no_grad():
|
|
39
|
+
l = norm_cdf((a - mean) / std)
|
|
40
|
+
u = norm_cdf((b - mean) / std)
|
|
41
|
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
|
42
|
+
tensor.erfinv_()
|
|
43
|
+
tensor.mul_(std * math.sqrt(2.))
|
|
44
|
+
tensor.add_(mean)
|
|
45
|
+
tensor.clamp_(min=a, max=b)
|
|
46
|
+
return tensor
|
|
47
|
+
|
|
48
|
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
|
49
|
+
'''使用截断正态分布填充输入张量。
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
tensor (torch.Tensor): 要填充的 n 维 torch.Tensor。
|
|
53
|
+
mean (float): 正态分布的均值。
|
|
54
|
+
std (float): 正态分布的标准差。
|
|
55
|
+
a (float): 最小截止值。
|
|
56
|
+
b (float): 最大截止值。
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
torch.Tensor: 修改后的张量。
|
|
60
|
+
'''
|
|
61
|
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
|
62
|
+
|
|
63
|
+
def constant_init(module, val, bias=0):
|
|
64
|
+
'''使用常数值初始化模块权重,可选初始化偏置。
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
module (nn.Module): 要初始化的模块。
|
|
68
|
+
val (float): 权重的常数值。
|
|
69
|
+
bias (float): 偏置的常数值。
|
|
70
|
+
'''
|
|
71
|
+
if hasattr(module, 'weight') and module.weight is not None:
|
|
72
|
+
nn.init.constant_(module.weight, val)
|
|
73
|
+
if hasattr(module, 'bias') and module.bias is not None:
|
|
74
|
+
nn.init.constant_(module.bias, bias)
|
|
75
|
+
|
|
76
|
+
def init_weights(module, method='kaiming', distribution='normal', bias=0,
|
|
77
|
+
a=0, mode='fan_out', nonlinearity='relu',
|
|
78
|
+
gain=1,
|
|
79
|
+
std=0.02, trunc_a=-2., trunc_b=2.):
|
|
80
|
+
'''通用权重初始化函数,支持多种初始化方法。
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
module (nn.Module): 要初始化的模块。
|
|
84
|
+
method (str): 初始化方法。可选值:
|
|
85
|
+
- 'kaiming': Kaiming (He) 初始化
|
|
86
|
+
- 'xavier': Xavier (Glorot) 初始化
|
|
87
|
+
- 'c2_xavier': Caffe2 风格的 Xavier 初始化
|
|
88
|
+
- 'orthogonal': 正交初始化
|
|
89
|
+
- 'trunc_normal': 截断正态分布
|
|
90
|
+
- 'normal': 标准正态分布
|
|
91
|
+
- 'constant': 常数初始化 (使用 val=gain)
|
|
92
|
+
distribution (str): 'uniform' 或 'normal' (用于 kaiming 和 xavier)。
|
|
93
|
+
bias (float): 偏置的初始化值。
|
|
94
|
+
a (float): Kaiming init 的负斜率。
|
|
95
|
+
mode (str): Kaiming init 的模式 ('fan_in', 'fan_out')。
|
|
96
|
+
nonlinearity (str): Kaiming init 的非线性函数 ('relu', 'leaky_relu' 等)。
|
|
97
|
+
gain (float): Xavier init 的缩放因子,或 Constant init 的值。
|
|
98
|
+
std (float): Normal/Truncated Normal 的标准差。
|
|
99
|
+
trunc_a (float): Truncated Normal 的下界。
|
|
100
|
+
trunc_b (float): Truncated Normal 的上界。
|
|
101
|
+
'''
|
|
102
|
+
if hasattr(module, 'weight') and module.weight is not None:
|
|
103
|
+
if method == 'kaiming':
|
|
104
|
+
if distribution == 'uniform':
|
|
105
|
+
nn.init.kaiming_uniform_(
|
|
106
|
+
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
|
107
|
+
else:
|
|
108
|
+
nn.init.kaiming_normal_(
|
|
109
|
+
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
|
110
|
+
|
|
111
|
+
elif method == 'xavier':
|
|
112
|
+
if distribution == 'uniform':
|
|
113
|
+
nn.init.xavier_uniform_(module.weight, gain=gain)
|
|
114
|
+
else:
|
|
115
|
+
nn.init.xavier_normal_(module.weight, gain=gain)
|
|
116
|
+
|
|
117
|
+
elif method == 'c2_xavier':
|
|
118
|
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(module.weight)
|
|
119
|
+
c2_std = math.sqrt(1.0 / float(fan_in))
|
|
120
|
+
nn.init.normal_(module.weight, mean=0.0, std=c2_std)
|
|
121
|
+
|
|
122
|
+
elif method == 'orthogonal':
|
|
123
|
+
nn.init.orthogonal_(module.weight, gain=gain)
|
|
124
|
+
|
|
125
|
+
elif method == 'trunc_normal':
|
|
126
|
+
trunc_normal_(
|
|
127
|
+
module.weight, mean=0., std=std, a=trunc_a, b=trunc_b)
|
|
128
|
+
|
|
129
|
+
elif method == 'normal':
|
|
130
|
+
nn.init.normal_(module.weight, mean=0., std=std)
|
|
131
|
+
|
|
132
|
+
elif method == 'constant':
|
|
133
|
+
nn.init.constant_(module.weight, val=gain)
|
|
134
|
+
|
|
135
|
+
else:
|
|
136
|
+
nn.init.xavier_uniform_(module.weight, gain=gain)
|
|
137
|
+
|
|
138
|
+
if hasattr(module, 'bias') and module.bias is not None:
|
|
139
|
+
nn.init.constant_(module.bias, bias)
|
|
140
|
+
|
|
141
|
+
def init_layer_norm(module, weight=1.0, bias=0.0):
|
|
142
|
+
'''初始化 LayerNorm 或 GroupNorm 模块。
|
|
143
|
+
|
|
144
|
+
Args:
|
|
145
|
+
module (nn.Module): 归一化模块。
|
|
146
|
+
weight (float): 权重的初始值 (gamma)。
|
|
147
|
+
bias (float): 偏置的初始值 (beta)。
|
|
148
|
+
'''
|
|
149
|
+
if hasattr(module, 'weight') and module.weight is not None:
|
|
150
|
+
nn.init.constant_(module.weight, weight)
|
|
151
|
+
if hasattr(module, 'bias') and module.bias is not None:
|
|
152
|
+
nn.init.constant_(module.bias, bias)
|
|
153
|
+
|
|
154
|
+
def init_embedding(module, init_method='normal', std=0.02, a=0., b=1., padding_idx=None):
|
|
155
|
+
'''初始化 Embedding 层。
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
module (nn.Embedding): Embedding 模块。
|
|
159
|
+
init_method (str): 'normal', 'trunc_normal', 'uniform'。
|
|
160
|
+
std (float): 正态分布的标准差。
|
|
161
|
+
a (float): 均匀分布的下界或截断正态分布的下界。
|
|
162
|
+
b (float): 均匀分布的上界或截断正态分布的上界。
|
|
163
|
+
padding_idx (int, optional): 如果指定,padding 索引的权重将被初始化为 0。
|
|
164
|
+
'''
|
|
165
|
+
if hasattr(module, 'weight') and module.weight is not None:
|
|
166
|
+
if init_method == 'normal':
|
|
167
|
+
nn.init.normal_(module.weight, mean=0., std=std)
|
|
168
|
+
elif init_method == 'trunc_normal':
|
|
169
|
+
trunc_normal_(module.weight, mean=0., std=std, a=a, b=b)
|
|
170
|
+
elif init_method == 'uniform':
|
|
171
|
+
nn.init.uniform_(module.weight, a=a, b=b)
|
|
172
|
+
|
|
173
|
+
if padding_idx is not None:
|
|
174
|
+
module.weight.data[padding_idx].zero_()
|
|
175
|
+
|
|
176
|
+
def init_weights_transformer(model, n_layer=None, initializer_range=0.02,
|
|
177
|
+
residual_proj_names=('linear_out', 'fc2', 'c_proj'),
|
|
178
|
+
verbose=False):
|
|
179
|
+
'''Transformer 模型的通用初始化逻辑,支持复杂的嵌套结构和残差缩放。
|
|
180
|
+
|
|
181
|
+
Args:
|
|
182
|
+
model (nn.Module): 要初始化的模型。
|
|
183
|
+
n_layer (int, optional): Transformer 的层数,用于残差缩放。
|
|
184
|
+
initializer_range (float): 初始化的标准差。
|
|
185
|
+
residual_proj_names (tuple): 需要应用残差缩放的模块名称关键字。
|
|
186
|
+
verbose (bool): 是否打印初始化信息。
|
|
187
|
+
'''
|
|
188
|
+
init_info = []
|
|
189
|
+
|
|
190
|
+
for name, module in model.named_modules():
|
|
191
|
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d)):
|
|
192
|
+
nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
|
|
193
|
+
if module.bias is not None:
|
|
194
|
+
nn.init.zeros_(module.bias)
|
|
195
|
+
|
|
196
|
+
info = f'Normal (std={initializer_range})'
|
|
197
|
+
if n_layer is not None:
|
|
198
|
+
module_name = name.split('.')[-1]
|
|
199
|
+
if any(proj_name in module_name for proj_name in residual_proj_names):
|
|
200
|
+
scale = 1.0 / math.sqrt(2.0 * n_layer)
|
|
201
|
+
module.weight.data.mul_(scale)
|
|
202
|
+
info = f'Residual Scaled (scale={scale:.4f})'
|
|
203
|
+
|
|
204
|
+
init_info.append((name, module.__class__.__name__, info))
|
|
205
|
+
|
|
206
|
+
elif isinstance(module, nn.Embedding):
|
|
207
|
+
nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
|
|
208
|
+
if hasattr(module, 'padding_idx') and module.padding_idx is not None:
|
|
209
|
+
module.weight.data[module.padding_idx].zero_()
|
|
210
|
+
init_info.append((name, 'Embedding', f'Normal (std={initializer_range})'))
|
|
211
|
+
|
|
212
|
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
|
|
213
|
+
nn.init.ones_(module.weight)
|
|
214
|
+
nn.init.zeros_(module.bias)
|
|
215
|
+
init_info.append((name, module.__class__.__name__, 'Ones/Zeros'))
|
|
216
|
+
|
|
217
|
+
if verbose:
|
|
218
|
+
_print_init_info(init_info)
|
|
219
|
+
|
|
220
|
+
class WeightInitializer:
|
|
221
|
+
'''多功能权重初始化器类,支持复杂的初始化策略配置。
|
|
222
|
+
|
|
223
|
+
Attributes:
|
|
224
|
+
method (str): 主要初始化方法。
|
|
225
|
+
distribution (str): 分布类型。
|
|
226
|
+
init_bias (float): 偏置初始值。
|
|
227
|
+
init_norm_weight (float): 归一化层权重初始值。
|
|
228
|
+
init_norm_bias (float): 归一化层偏置初始值。
|
|
229
|
+
'''
|
|
230
|
+
|
|
231
|
+
def __init__(
|
|
232
|
+
self,
|
|
233
|
+
method='kaiming',
|
|
234
|
+
distribution='normal',
|
|
235
|
+
mode='fan_out',
|
|
236
|
+
nonlinearity='relu',
|
|
237
|
+
init_bias=0.0,
|
|
238
|
+
init_norm_weight=1.0,
|
|
239
|
+
init_norm_bias=0.0,
|
|
240
|
+
std=0.02,
|
|
241
|
+
trunc_a=-2.0,
|
|
242
|
+
trunc_b=2.0
|
|
243
|
+
):
|
|
244
|
+
'''初始化 WeightInitializer。
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
method (str): 初始化方法,默认为 'kaiming'。
|
|
248
|
+
distribution (str): 'uniform' 或 'normal'。
|
|
249
|
+
mode (str): 'fan_in' 或 'fan_out' (用于 kaiming)。
|
|
250
|
+
nonlinearity (str): 非线性函数名 (用于 kaiming)。
|
|
251
|
+
init_bias (float): 线性/卷积层的偏置初始值。
|
|
252
|
+
init_norm_weight (float): Norm 层的权重初始值。
|
|
253
|
+
init_norm_bias (float): Norm 层的偏置初始值。
|
|
254
|
+
std (float): 用于 'normal' 或 'trunc_normal' 的标准差。
|
|
255
|
+
trunc_a (float): 截断正态分布下界。
|
|
256
|
+
trunc_b (float): 截断正态分布上界。
|
|
257
|
+
'''
|
|
258
|
+
self.method = method
|
|
259
|
+
self.distribution = distribution
|
|
260
|
+
self.mode = mode
|
|
261
|
+
self.nonlinearity = nonlinearity
|
|
262
|
+
self.init_bias = init_bias
|
|
263
|
+
self.init_norm_weight = init_norm_weight
|
|
264
|
+
self.init_norm_bias = init_norm_bias
|
|
265
|
+
self.std = std
|
|
266
|
+
self.trunc_a = trunc_a
|
|
267
|
+
self.trunc_b = trunc_b
|
|
268
|
+
|
|
269
|
+
def apply(self, model, override=None, verbose=False):
|
|
270
|
+
'''将初始化策略应用于模型的所有子模块。
|
|
271
|
+
|
|
272
|
+
Args:
|
|
273
|
+
model (nn.Module): 要初始化的模型。
|
|
274
|
+
override (dict, optional): 针对特定层名称的覆盖配置。
|
|
275
|
+
格式: {'regex_pattern': {'method': '...', ...}}
|
|
276
|
+
verbose (bool): 是否打印初始化信息。
|
|
277
|
+
'''
|
|
278
|
+
init_info = []
|
|
279
|
+
|
|
280
|
+
for name, module in model.named_modules():
|
|
281
|
+
current_config = {}
|
|
282
|
+
if override:
|
|
283
|
+
for pattern, config in override.items():
|
|
284
|
+
if re.search(pattern, name):
|
|
285
|
+
current_config = config
|
|
286
|
+
break
|
|
287
|
+
|
|
288
|
+
method = current_config.get('method', self.method)
|
|
289
|
+
distribution = current_config.get('distribution', self.distribution)
|
|
290
|
+
mode = current_config.get('mode', self.mode)
|
|
291
|
+
nonlinearity = current_config.get('nonlinearity', self.nonlinearity)
|
|
292
|
+
std = current_config.get('std', self.std)
|
|
293
|
+
|
|
294
|
+
if isinstance(module, (nn.Conv2d, nn.Conv1d, nn.Conv3d, nn.Linear)):
|
|
295
|
+
init_weights(
|
|
296
|
+
module,
|
|
297
|
+
method=method,
|
|
298
|
+
distribution=distribution,
|
|
299
|
+
bias=self.init_bias,
|
|
300
|
+
mode=mode,
|
|
301
|
+
nonlinearity=nonlinearity,
|
|
302
|
+
std=std,
|
|
303
|
+
trunc_a=self.trunc_a,
|
|
304
|
+
trunc_b=self.trunc_b
|
|
305
|
+
)
|
|
306
|
+
info = f'{method} ({distribution})'
|
|
307
|
+
if method == 'kaiming':
|
|
308
|
+
info += f', mode={mode}, nonlin={nonlinearity}'
|
|
309
|
+
elif method == 'normal':
|
|
310
|
+
info += f', std={std}'
|
|
311
|
+
init_info.append((name, module.__class__.__name__, info))
|
|
312
|
+
|
|
313
|
+
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.BatchNorm1d)):
|
|
314
|
+
init_layer_norm(
|
|
315
|
+
module, weight=self.init_norm_weight, bias=self.init_norm_bias)
|
|
316
|
+
init_info.append((name, module.__class__.__name__, 'Norm (1/0)'))
|
|
317
|
+
|
|
318
|
+
elif isinstance(module, nn.Embedding):
|
|
319
|
+
init_embedding(
|
|
320
|
+
module,
|
|
321
|
+
init_method=method if method in ['normal', 'trunc_normal', 'uniform'] else 'normal',
|
|
322
|
+
std=std
|
|
323
|
+
)
|
|
324
|
+
init_info.append((name, 'Embedding', f'{method} (std={std})'))
|
|
325
|
+
|
|
326
|
+
if verbose:
|
|
327
|
+
_print_init_info(init_info)
|
|
328
|
+
|
|
329
|
+
def _print_init_info(init_info):
|
|
330
|
+
'''打印初始化信息的辅助函数,支持 rich 美化。
|
|
331
|
+
|
|
332
|
+
Args:
|
|
333
|
+
init_info (list): 包含 (layer_name, module_type, details) 元组的列表。
|
|
334
|
+
'''
|
|
335
|
+
if not init_info:
|
|
336
|
+
return
|
|
337
|
+
|
|
338
|
+
if RICH_AVAILABLE:
|
|
339
|
+
console = Console()
|
|
340
|
+
table = Table(title="Weight Initialization Report", show_header=True, header_style="bold magenta")
|
|
341
|
+
table.add_column("Layer Name", style="cyan")
|
|
342
|
+
table.add_column("Module Type", style="green")
|
|
343
|
+
table.add_column("Initialization Details", style="yellow")
|
|
344
|
+
|
|
345
|
+
for name, type_name, details in init_info:
|
|
346
|
+
table.add_row(name, type_name, details)
|
|
347
|
+
|
|
348
|
+
console.print(table)
|
|
349
|
+
else:
|
|
350
|
+
print(f"{'Layer Name':<40} | {'Module Type':<20} | {'Initialization Details'}")
|
|
351
|
+
print("-" * 90)
|
|
352
|
+
for name, type_name, details in init_info:
|
|
353
|
+
print(f"{name:<40} | {type_name:<20} | {details}")
|
|
354
|
+
|
|
355
|
+
def initialize_weights(model, method='kaiming', override=None, verbose=False, **kwargs):
|
|
356
|
+
'''初始化模型权重的便捷函数。
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
model (nn.Module): 要初始化的模型。
|
|
360
|
+
method (str): 初始化方法。
|
|
361
|
+
override (dict, optional): 针对特定层名称的覆盖配置。
|
|
362
|
+
verbose (bool): 是否打印初始化信息。
|
|
363
|
+
**kwargs: 传递给 WeightInitializer 的其他参数。
|
|
364
|
+
'''
|
|
365
|
+
initializer = WeightInitializer(method=method, **kwargs)
|
|
366
|
+
initializer.apply(model, override=override, verbose=verbose)
|
|
367
|
+
|
|
368
|
+
class AutoInitializer:
|
|
369
|
+
'''自动初始化器,通过分析模型结构统计信息来应用最优初始化策略。
|
|
370
|
+
|
|
371
|
+
该类会自动探测模型的深度、激活函数分布以及是否包含 Transformer 结构,
|
|
372
|
+
并据此推荐合适的初始化方法(如 Kaiming, Xavier, 或带残差缩放的正态分布)。
|
|
373
|
+
|
|
374
|
+
Attributes:
|
|
375
|
+
model (nn.Module): 需要初始化的模型。
|
|
376
|
+
stats (dict): 模型分析统计信息。
|
|
377
|
+
'''
|
|
378
|
+
|
|
379
|
+
def __init__(self, model):
|
|
380
|
+
'''初始化 AutoInitializer。
|
|
381
|
+
|
|
382
|
+
Args:
|
|
383
|
+
model (nn.Module): 需要初始化的模型。
|
|
384
|
+
'''
|
|
385
|
+
self.model = model
|
|
386
|
+
self.stats = self._analyze_model()
|
|
387
|
+
|
|
388
|
+
def _analyze_model(self):
|
|
389
|
+
'''分析模型结构,收集统计信息。
|
|
390
|
+
|
|
391
|
+
Returns:
|
|
392
|
+
dict: 包含深度、激活函数分布、层类型分布等信息的字典。
|
|
393
|
+
'''
|
|
394
|
+
stats = {
|
|
395
|
+
'depth': 0,
|
|
396
|
+
'activations': {},
|
|
397
|
+
'layer_types': {},
|
|
398
|
+
'transformer_detected': False
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
# 简单的深度估计:计算包含参数的层数
|
|
402
|
+
param_layers = [m for m in self.model.modules() if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d))]
|
|
403
|
+
stats['depth'] = len(param_layers)
|
|
404
|
+
|
|
405
|
+
# 激活函数探测
|
|
406
|
+
for m in self.model.modules():
|
|
407
|
+
name = m.__class__.__name__
|
|
408
|
+
if 'ReLU' in name:
|
|
409
|
+
stats['activations']['relu'] = stats['activations'].get('relu', 0) + 1
|
|
410
|
+
elif 'GELU' in name:
|
|
411
|
+
stats['activations']['gelu'] = stats['activations'].get('gelu', 0) + 1
|
|
412
|
+
elif 'Tanh' in name:
|
|
413
|
+
stats['activations']['tanh'] = stats['activations'].get('tanh', 0) + 1
|
|
414
|
+
elif 'Sigmoid' in name:
|
|
415
|
+
stats['activations']['sigmoid'] = stats['activations'].get('sigmoid', 0) + 1
|
|
416
|
+
|
|
417
|
+
# Transformer 检测
|
|
418
|
+
if 'Attention' in name or 'Transformer' in name:
|
|
419
|
+
stats['transformer_detected'] = True
|
|
420
|
+
|
|
421
|
+
stats['layer_types'][name] = stats['layer_types'].get(name, 0) + 1
|
|
422
|
+
|
|
423
|
+
return stats
|
|
424
|
+
|
|
425
|
+
def recommend_config(self):
|
|
426
|
+
'''基于统计信息推荐初始化配置。
|
|
427
|
+
|
|
428
|
+
Returns:
|
|
429
|
+
tuple: (method, nonlinearity, config)
|
|
430
|
+
- method (str): 推荐的主要初始化方法。
|
|
431
|
+
- nonlinearity (str): 推荐的非线性函数参数。
|
|
432
|
+
- config (dict): 针对特定层的覆盖配置。
|
|
433
|
+
'''
|
|
434
|
+
config = {}
|
|
435
|
+
method = 'kaiming' # 默认
|
|
436
|
+
nonlinearity = 'relu'
|
|
437
|
+
|
|
438
|
+
# 1. 确定主要激活函数
|
|
439
|
+
acts = self.stats['activations']
|
|
440
|
+
if acts:
|
|
441
|
+
main_act = max(acts, key=acts.get)
|
|
442
|
+
if main_act == 'relu':
|
|
443
|
+
method = 'kaiming'
|
|
444
|
+
nonlinearity = 'relu'
|
|
445
|
+
elif main_act == 'gelu':
|
|
446
|
+
method = 'trunc_normal' # GELU 通常配合正态分布
|
|
447
|
+
elif main_act in ['tanh', 'sigmoid']:
|
|
448
|
+
method = 'xavier'
|
|
449
|
+
|
|
450
|
+
# 2. Transformer 特殊处理
|
|
451
|
+
if self.stats['transformer_detected']:
|
|
452
|
+
# 对于 Transformer,通常使用正态分布
|
|
453
|
+
method = 'normal'
|
|
454
|
+
# 残差缩放配置
|
|
455
|
+
n_layers = max(1, self.stats['depth'] // 4) # 粗略估计 Block 数量
|
|
456
|
+
scale = 1.0 / math.sqrt(2.0 * n_layers)
|
|
457
|
+
|
|
458
|
+
# 针对投影层的覆盖配置
|
|
459
|
+
config['.*linear_out.*'] = {'method': 'normal', 'std': 0.02 * scale}
|
|
460
|
+
config['.*fc2.*'] = {'method': 'normal', 'std': 0.02 * scale}
|
|
461
|
+
config['.*c_proj.*'] = {'method': 'normal', 'std': 0.02 * scale}
|
|
462
|
+
|
|
463
|
+
return method, nonlinearity, config
|
|
464
|
+
|
|
465
|
+
def apply(self, verbose=True):
|
|
466
|
+
'''应用自动生成的初始化策略。
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
verbose (bool): 是否打印分析报告和初始化详情。
|
|
470
|
+
'''
|
|
471
|
+
method, nonlinearity, override = self.recommend_config()
|
|
472
|
+
|
|
473
|
+
if verbose:
|
|
474
|
+
if RICH_AVAILABLE:
|
|
475
|
+
console = Console()
|
|
476
|
+
console.print(f"[bold cyan]Auto Initialization Analysis:[/bold cyan]")
|
|
477
|
+
console.print(f" Depth: {self.stats['depth']}")
|
|
478
|
+
console.print(f" Activations: {self.stats['activations']}")
|
|
479
|
+
console.print(f" Transformer Detected: {self.stats['transformer_detected']}")
|
|
480
|
+
console.print(f"[bold green]Recommended Strategy:[/bold green] {method} (nonlin={nonlinearity})")
|
|
481
|
+
else:
|
|
482
|
+
print(f"Auto Init: Depth={self.stats['depth']}, Acts={self.stats['activations']}")
|
|
483
|
+
print(f"Strategy: {method}, {nonlinearity}")
|
|
484
|
+
|
|
485
|
+
initialize_weights(
|
|
486
|
+
self.model,
|
|
487
|
+
method=method,
|
|
488
|
+
nonlinearity=nonlinearity,
|
|
489
|
+
override=override,
|
|
490
|
+
verbose=verbose
|
|
491
|
+
)
|
|
492
|
+
|
|
493
|
+
def auto_initialize(model, verbose=True):
|
|
494
|
+
'''自动初始化的便捷入口函数。
|
|
495
|
+
|
|
496
|
+
Args:
|
|
497
|
+
model (nn.Module): 需要初始化的模型。
|
|
498
|
+
verbose (bool): 是否打印初始化信息。
|
|
499
|
+
'''
|
|
500
|
+
initializer = AutoInitializer(model)
|
|
501
|
+
initializer.apply(verbose=verbose)
|
orbit/utils/layer_io.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
def get_module_by_name(model: nn.Module, name: str) -> nn.Module:
|
|
6
|
+
'''通过名称获取模型的子模块。
|
|
7
|
+
|
|
8
|
+
Args:
|
|
9
|
+
model (nn.Module): 目标模型。
|
|
10
|
+
name (str): 子模块名称,例如 'backbone.layer1'。
|
|
11
|
+
|
|
12
|
+
Returns:
|
|
13
|
+
nn.Module: 找到的子模块。
|
|
14
|
+
|
|
15
|
+
Raises:
|
|
16
|
+
AttributeError: 如果找不到指定的模块名称。
|
|
17
|
+
'''
|
|
18
|
+
names = name.split('.')
|
|
19
|
+
module = model
|
|
20
|
+
for n in names:
|
|
21
|
+
if not hasattr(module, n):
|
|
22
|
+
raise AttributeError(f"Module '{type(module).__name__}' has no attribute '{n}'")
|
|
23
|
+
module = getattr(module, n)
|
|
24
|
+
return module
|
|
25
|
+
|
|
26
|
+
def save_layer_weights(model: nn.Module, layer_name: str, file_path: str) -> None:
|
|
27
|
+
'''保存模型指定层的权重到文件。
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
model (nn.Module): 目标模型。
|
|
31
|
+
layer_name (str): 要保存权重的层名称。
|
|
32
|
+
file_path (str): 保存路径。
|
|
33
|
+
'''
|
|
34
|
+
module = get_module_by_name(model, layer_name)
|
|
35
|
+
torch.save(module.state_dict(), file_path)
|
|
36
|
+
|
|
37
|
+
def load_layer_weights(
|
|
38
|
+
model: nn.Module,
|
|
39
|
+
layer_name: str,
|
|
40
|
+
file_path: str,
|
|
41
|
+
strict: bool = True,
|
|
42
|
+
map_location: Union[str, torch.device] = 'cpu'
|
|
43
|
+
) -> None:
|
|
44
|
+
'''从文件加载权重到模型的指定层。
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
model (nn.Module): 目标模型。
|
|
48
|
+
layer_name (str): 要加载权重的层名称。
|
|
49
|
+
file_path (str): 权重文件路径。
|
|
50
|
+
strict (bool): 是否严格匹配键值。默认为 True。
|
|
51
|
+
map_location (str or torch.device): 加载位置。默认为 'cpu'。
|
|
52
|
+
'''
|
|
53
|
+
state_dict = torch.load(file_path, map_location=map_location)
|
|
54
|
+
module = get_module_by_name(model, layer_name)
|
|
55
|
+
module.load_state_dict(state_dict, strict=strict)
|