orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/utils/image.py
ADDED
@@ -0,0 +1,164 @@
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+from dataclasses import dataclass
+
+@dataclass
+class PatchOutput:
+    output: torch.Tensor
+    mask: torch.Tensor
+    patch_size: Tuple[int, int]
+    num_patches: Tuple[int, int]
+
+def pad_to_patch_size(image: torch.Tensor, patch_size: Tuple[int, int]) -> PatchOutput:
+    '''Pad an image to fit the patch size, without splitting it.
+
+    Takes an image tensor of shape [..., channels, width, height] and zero-pads
+    it on the right and bottom so that the padded size is divisible by patch_size.
+
+    Args:
+        image (torch.Tensor): input image tensor of shape [..., channels, w, h].
+            The last two dimensions are treated as spatial (width, height).
+        patch_size (Tuple[int, int]): patch size (a, b), where 'a' corresponds to
+            the width dimension and 'b' to the height dimension.
+
+    Returns:
+        PatchOutput: a dataclass with the following fields:
+            - output (torch.Tensor): padded image tensor of shape [..., channels, w_padded, h_padded].
+            - mask (torch.Tensor): mask tensor of shape [..., 1, w_padded, h_padded];
+              1 in valid regions, 0 in padded regions.
+            - patch_size (Tuple[int, int]): the input patch size (a, b).
+            - num_patches (Tuple[int, int]): patch counts along width and height (num_w, num_h).
+
+    Raises:
+        ValueError: if the input image has fewer than 3 dimensions.
+    '''
+    if image.ndim < 3:
+        raise ValueError(f'Input image must have at least 3 dimensions, got {image.ndim}')
+
+    w, h = image.shape[-2], image.shape[-1]
+    a, b = patch_size
+
+    pad_w = (a - w % a) % a
+    pad_h = (b - h % b) % b
+
+    image_padded = F.pad(image, (0, pad_h, 0, pad_w))
+
+    w_padded = w + pad_w
+    h_padded = h + pad_h
+
+    num_w = w_padded // a
+    num_h = h_padded // b
+
+    mask = torch.ones((*image.shape[:-3], 1, w, h), dtype=image.dtype, device=image.device)
+    mask_padded = F.pad(mask, (0, pad_h, 0, pad_w), value=0)
+
+    return PatchOutput(
+        output=image_padded,
+        mask=mask_padded,
+        patch_size=patch_size,
+        num_patches=(num_w, num_h)
+    )
+
+def split_to_patches(image: torch.Tensor, patch_size: Tuple[int, int]) -> PatchOutput:
+    '''Split an image tensor into sub-images, with automatic padding.
+
+    Takes an image tensor of shape [..., channels, width, height] and divides
+    it into patches of shape [channels, patch_width, patch_height].
+    If the image size is not divisible by patch_size, the image is zero-padded
+    on the right and bottom.
+    The resulting tensor has shape [..., num_patches_total, channels, patch_width, patch_height].
+
+    Args:
+        image (torch.Tensor): input image tensor of shape [..., channels, w, h].
+            The last two dimensions are treated as spatial (width, height).
+        patch_size (Tuple[int, int]): patch size (a, b), where 'a' corresponds to
+            the width dimension and 'b' to the height dimension.
+
+    Returns:
+        PatchOutput: a dataclass with the following fields:
+            - output (torch.Tensor): patch tensor of shape [..., num_w * num_h, channels, a, b].
+            - mask (torch.Tensor): mask tensor of shape [..., num_w * num_h, 1, a, b];
+              1 in valid regions, 0 in padded regions.
+            - patch_size (Tuple[int, int]): the input patch size (a, b).
+            - num_patches (Tuple[int, int]): patch counts along width and height (num_w, num_h).
+
+    Raises:
+        ValueError: if the input image has fewer than 3 dimensions.
+    '''
+    padded_output = pad_to_patch_size(image, patch_size)
+    image_padded = padded_output.output
+    mask_padded = padded_output.mask
+    num_w, num_h = padded_output.num_patches
+    a, b = patch_size
+
+    def split(x, nw, nh):
+        # x shape: [..., C, W, H]
+        reshaped = x.view(*x.shape[:-2], nw, a, nh, b)
+        # permute to [..., nw, nh, C, a, b]
+        permuted = reshaped.permute(
+            *range(reshaped.ndim - 5),
+            reshaped.ndim - 4,  # nw
+            reshaped.ndim - 2,  # nh
+            reshaped.ndim - 5,  # C
+            reshaped.ndim - 3,  # a
+            reshaped.ndim - 1   # b
+        )
+        return permuted.reshape(*x.shape[:-3], nw * nh, x.shape[-3], a, b)
+
+    output_patches = split(image_padded, num_w, num_h)
+    mask_patches = split(mask_padded, num_w, num_h)
+
+    return PatchOutput(
+        output=output_patches,
+        mask=mask_patches,
+        patch_size=patch_size,
+        num_patches=(num_w, num_h)
+    )
+
+def reconstruct_from_patches(patches: torch.Tensor, num_patches: Tuple[int, int], mask: torch.Tensor = None) -> torch.Tensor:
+    '''Reconstruct an image from patches.
+
+    This function is the inverse of split_to_patches: it reassembles a patch
+    tensor into the original image. If a mask is provided, padding is stripped
+    according to the mask.
+
+    Args:
+        patches (torch.Tensor): patch tensor of shape [..., num_patches, channels, patch_width, patch_height].
+        num_patches (Tuple[int, int]): patch counts along width and height (num_w, num_h).
+        mask (torch.Tensor, optional): mask used to strip the padding; same shape as
+            patches but with a single channel. If provided, the reconstructed image
+            is cropped according to the mask to remove the padding.
+
+    Returns:
+        torch.Tensor: reconstructed image tensor of shape [..., channels, width, height].
+    '''
+    nw, nh = num_patches
+    if patches.shape[-4] != nw * nh:
+        raise ValueError(f"Number of patches in tensor ({patches.shape[-4]}) does not match num_patches argument ({nw} * {nh} = {nw*nh})")
+
+    a, b = patches.shape[-2], patches.shape[-1]
+
+    def unsplit(x):
+        # x: [..., nw*nh, C, a, b]
+        reshaped = x.view(*x.shape[:-4], nw, nh, *x.shape[-3:])
+        # permute to [..., C, nw, a, nh, b]
+        permuted = reshaped.permute(
+            *range(reshaped.ndim - 5),
+            reshaped.ndim - 3,  # C
+            reshaped.ndim - 5,  # nw
+            reshaped.ndim - 2,  # a
+            reshaped.ndim - 4,  # nh
+            reshaped.ndim - 1   # b
+        )
+        return permuted.reshape(*x.shape[:-4], x.shape[-3], nw * a, nh * b)
+
+    reconstructed = unsplit(patches)
+
+    if mask is not None:
+        reconstructed_mask = unsplit(mask)
+        m = reconstructed_mask.view(-1, reconstructed_mask.shape[-2], reconstructed_mask.shape[-1])
+
+        valid_h = (m[0, 0, :] > 0.5).sum().item()
+        valid_w = (m[0, :, 0] > 0.5).sum().item()
+
+        reconstructed = reconstructed[..., :int(valid_w), :int(valid_h)]
+
+    return reconstructed
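
Taken together, the three helpers round-trip cleanly: split_to_patches pads and tiles, and reconstruct_from_patches undoes both steps when handed the mask. A minimal usage sketch (shapes are arbitrary, chosen so the padding path is actually exercised):

import torch
from orbit.utils.image import split_to_patches, reconstruct_from_patches

image = torch.randn(2, 3, 300, 500)            # [batch, channels, w, h]
patches = split_to_patches(image, (128, 128))  # pads to 384 x 512 internally

print(patches.output.shape)  # torch.Size([2, 12, 3, 128, 128])
print(patches.num_patches)   # (3, 4)

# The mask records which pixels are real; passing it back crops the zero
# padding off again, so the round trip is exact.
restored = reconstruct_from_patches(patches.output, patches.num_patches,
                                    mask=patches.mask)
assert restored.shape == image.shape
assert torch.equal(restored, image)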
orbit/utils/initialization.py
CHANGED
@@ -4,13 +4,8 @@ import re
 import torch
 import torch.nn as nn
 from torch.nn.init import _calculate_fan_in_and_fan_out
-
-try:
-    from rich.console import Console
-    from rich.table import Table
-    RICH_AVAILABLE = True
-except ImportError:
-    RICH_AVAILABLE = False
+from rich.console import Console
+from rich.table import Table
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     '''Helper for truncated-normal initialization; runs in no-grad mode.
@@ -68,11 +63,63 @@ def constant_init(module, val, bias=0):
         val (float): constant value for the weight.
         bias (float): constant value for the bias.
     '''
+    if isinstance(module, (nn.Parameter, torch.Tensor)):
+        nn.init.constant_(module, val)
+        return
+
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.constant_(module.weight, val)
     if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
 
+def _init_tensor_impl(tensor, method, distribution, a, mode, nonlinearity, gain, std, trunc_a, trunc_b):
+    '''Internal helper: apply an initialization method to a single tensor.'''
+    info = ""
+    if method == 'kaiming':
+        if distribution == 'uniform':
+            nn.init.kaiming_uniform_(
+                tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+        else:
+            nn.init.kaiming_normal_(
+                tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+        info = f'Kaiming ({distribution}), mode={mode}, nonlin={nonlinearity}'
+
+    elif method == 'xavier':
+        if distribution == 'uniform':
+            nn.init.xavier_uniform_(tensor, gain=gain)
+        else:
+            nn.init.xavier_normal_(tensor, gain=gain)
+        info = f'Xavier ({distribution}), gain={gain}'
+
+    elif method == 'c2_xavier':
+        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+        c2_std = math.sqrt(1.0 / float(fan_in))
+        nn.init.normal_(tensor, mean=0.0, std=c2_std)
+        info = f'C2 Xavier (Normal), std={c2_std:.4f}'
+
+    elif method == 'orthogonal':
+        nn.init.orthogonal_(tensor, gain=gain)
+        info = f'Orthogonal, gain={gain}'
+
+    elif method == 'trunc_normal':
+        trunc_normal_(
+            tensor, mean=0., std=std, a=trunc_a, b=trunc_b)
+        info = f'Trunc Normal, std={std}, a={trunc_a}, b={trunc_b}'
+
+    elif method == 'normal':
+        nn.init.normal_(tensor, mean=0., std=std)
+        info = f'Normal, std={std}'
+
+    elif method == 'constant':
+        nn.init.constant_(tensor, val=gain)
+        info = f'Constant, val={gain}'
+
+    else:
+        nn.init.xavier_uniform_(tensor, gain=gain)
+        info = f'Xavier (Uniform) [Default], gain={gain}'
+
+    return info
+
 def init_weights(module, method='kaiming', distribution='normal', bias=0,
                  a=0, mode='fan_out', nonlinearity='relu',
                  gain=1,
@@ -98,45 +145,56 @@ def init_weights(module, method='kaiming', distribution='normal', bias=0,
         std (float): standard deviation for Normal/Truncated Normal.
         trunc_a (float): lower bound for Truncated Normal.
         trunc_b (float): upper bound for Truncated Normal.
-    '''
-    if hasattr(module, 'weight') and module.weight is not None:
-        if method == 'kaiming':
-            if distribution == 'uniform':
-                nn.init.kaiming_uniform_(
-                    module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-            else:
-                nn.init.kaiming_normal_(
-                    module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-
-        elif method == 'xavier':
-            if distribution == 'uniform':
-                nn.init.xavier_uniform_(module.weight, gain=gain)
-            else:
-                nn.init.xavier_normal_(module.weight, gain=gain)
 
-
-
-
-
-
-
-
-
-
-
-
+    Returns:
+        str: a string describing the initialization, or None if nothing was initialized.
+    '''
+    # 1. Handle a Parameter/Tensor directly
+    if isinstance(module, (nn.Parameter, torch.Tensor)):
+        return _init_tensor_impl(module, method, distribution, a, mode, nonlinearity, gain, std, trunc_a, trunc_b)
+
+    # 2. Handle a Module
+    info_parts = []
+    handled_params = set()
+
+    def init_and_record(tensor, name, is_bias=False):
+        if id(tensor) in handled_params:
+            return
 
-
-            nn.init.
-
-
-            nn.init.constant_(module.weight, val=gain)
-
+        if is_bias:
+            nn.init.constant_(tensor, bias)
+            # Simplified output: a standard zero bias hardly needs detail, but keep it for clarity
+            info = f"bias={bias}" if name == 'bias' else f"{name}: Constant({bias})"
         else:
-
+            info = _init_tensor_impl(tensor, method, distribution, a, mode, nonlinearity, gain, std, trunc_a, trunc_b)
+            if name != 'weight':
+                info = f"{name}: {info}"
+
+        info_parts.append(info)
+        handled_params.add(id(tensor))
+
+    # A. Handle the standard 'weight' attribute first
+    if hasattr(module, 'weight') and module.weight is not None:
+        init_and_record(module.weight, 'weight', is_bias=False)
 
+    # B. Then the standard 'bias' attribute
     if hasattr(module, 'bias') and module.bias is not None:
-
+        init_and_record(module.bias, 'bias', is_bias=True)
+
+    # C. Iterate over all registered parameters (handles custom names);
+    # recurse=False ensures only the module's direct parameters are touched
+    for name, param in module.named_parameters(recurse=False):
+        if id(param) in handled_params:
+            continue
+
+        # Heuristic: ndim < 2 is treated as bias-like, anything else as weight-like
+        is_bias_like = param.ndim < 2
+        init_and_record(param, name, is_bias=is_bias_like)
+
+    if not info_parts:
+        return None
+
+    return ", ".join(info_parts)
 
 def init_layer_norm(module, weight=1.0, bias=0.0):
     '''Initialize a LayerNorm or GroupNorm module.
@@ -145,11 +203,21 @@ def init_layer_norm(module, weight=1.0, bias=0.0):
         module (nn.Module): the normalization module.
         weight (float): initial value for the weight (gamma).
         bias (float): initial value for the bias (beta).
+
+    Returns:
+        str: a string describing the initialization.
     '''
+    initialized = False
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.constant_(module.weight, weight)
+        initialized = True
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)
+        initialized = True
+
+    if initialized:
+        return f'Norm (w={weight}, b={bias})'
+    return None
 
 def init_embedding(module, init_method='normal', std=0.02, a=0., b=1., padding_idx=None):
     '''Initialize an Embedding layer.
@@ -161,17 +229,32 @@ def init_embedding(module, init_method='normal', std=0.02, a=0., b=1., padding_idx=None):
         a (float): lower bound of the uniform or truncated normal distribution.
         b (float): upper bound of the uniform or truncated normal distribution.
         padding_idx (int, optional): if given, the weight at the padding index is initialized to 0.
+
+    Returns:
+        str: a string describing the initialization.
     '''
-    if hasattr(module, 'weight') and module.weight is not None:
-
-
-
-
-
+    if not (hasattr(module, 'weight') and module.weight is not None):
+        return None
+
+    info = ""
+    if init_method == 'normal':
+        nn.init.normal_(module.weight, mean=0., std=std)
+        info = f'Normal (std={std})'
+    elif init_method == 'trunc_normal':
+        trunc_normal_(module.weight, mean=0., std=std, a=a, b=b)
+        info = f'Trunc Normal (std={std}, [{a}, {b}])'
+    elif init_method == 'uniform':
+        nn.init.uniform_(module.weight, a=a, b=b)
+        info = f'Uniform ([{a}, {b}])'
+    else:
+        nn.init.normal_(module.weight, mean=0., std=std)
+        info = f'Normal (std={std})'
 
     if padding_idx is not None:
         module.weight.data[padding_idx].zero_()
+        info += f', pad_idx={padding_idx}'
+
+    return info
 
 def init_weights_transformer(model, n_layer=None, initializer_range=0.02,
                              residual_proj_names=('linear_out', 'fc2', 'c_proj'),
@@ -277,6 +360,25 @@ class WeightInitializer:
         '''
         init_info = []
 
+        # Handle a single Parameter/Tensor
+        if isinstance(model, (nn.Parameter, torch.Tensor)):
+            info = init_weights(
+                model,
+                method=self.method,
+                distribution=self.distribution,
+                bias=self.init_bias,
+                mode=self.mode,
+                nonlinearity=self.nonlinearity,
+                std=self.std,
+                trunc_a=self.trunc_a,
+                trunc_b=self.trunc_b
+            )
+            if info:
+                init_info.append(('Parameter/Tensor', type(model).__name__, info))
+            if verbose:
+                _print_init_info(init_info)
+            return
+
         for name, module in model.named_modules():
             current_config = {}
             if override:
@@ -291,8 +393,19 @@ class WeightInitializer:
             nonlinearity = current_config.get('nonlinearity', self.nonlinearity)
             std = current_config.get('std', self.std)
 
-
-
+            info = None
+
+            if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.BatchNorm1d)):
+                info = init_layer_norm(
+                    module, weight=self.init_norm_weight, bias=self.init_norm_bias)
+
+            elif isinstance(module, nn.Embedding):
+                emb_method = method if method in ['normal', 'trunc_normal', 'uniform'] else 'normal'
+                info = init_embedding(module, init_method=emb_method, std=std)
+
+            else:
+                # Try generic initialization (Linear, Conv, or any layer with weight/bias)
+                info = init_weights(
                     module,
                     method=method,
                     distribution=distribution,
@@ -303,31 +416,15 @@ class WeightInitializer:
                     trunc_a=self.trunc_a,
                     trunc_b=self.trunc_b
                 )
-                info = f'{method} ({distribution})'
-                if method == 'kaiming':
-                    info += f', mode={mode}, nonlin={nonlinearity}'
-                elif method == 'normal':
-                    info += f', std={std}'
-                init_info.append((name, module.__class__.__name__, info))
-
-            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.BatchNorm1d)):
-                init_layer_norm(
-                    module, weight=self.init_norm_weight, bias=self.init_norm_bias)
-                init_info.append((name, module.__class__.__name__, 'Norm (1/0)'))
 
-            elif isinstance(module, nn.Embedding):
-                init_embedding(
-                    module,
-                    init_method=method if method in ['normal', 'trunc_normal', 'uniform'] else 'normal',
-                    std=std
-                )
-                init_info.append((name, 'Embedding', f'{method} (std={std})'))
+            if info:
+                init_info.append((name, module.__class__.__name__, info))
 
         if verbose:
             _print_init_info(init_info)
 
 def _print_init_info(init_info):
-    '''
+    '''Helper that prints the initialization report, rendered with rich.
 
     Args:
         init_info (list): list of (layer_name, module_type, details) tuples.
@@ -335,22 +432,16 @@ def _print_init_info(init_info):
     if not init_info:
         return
 
-
-
-
-
-
-    table.add_column("Initialization Details", style="yellow")
+    console = Console()
+    table = Table(title="Weight Initialization Report", show_header=True, header_style="bold magenta")
+    table.add_column("Layer Name", style="cyan")
+    table.add_column("Module Type", style="green")
+    table.add_column("Initialization Details", style="yellow")
 
-
-
-
-
-    else:
-        print(f"{'Layer Name':<40} | {'Module Type':<20} | {'Initialization Details'}")
-        print("-" * 90)
-        for name, type_name, details in init_info:
-            print(f"{name:<40} | {type_name:<20} | {details}")
+    for name, type_name, details in init_info:
+        table.add_row(str(name), str(type_name), str(details))
+
+    console.print(table)
 
 def initialize_weights(model, method='kaiming', override=None, verbose=False, **kwargs):
     '''Convenience function to initialize model weights.
@@ -398,6 +489,9 @@ class AutoInitializer:
             'transformer_detected': False
         }
 
+        if isinstance(self.model, (nn.Parameter, torch.Tensor)):
+            return stats
+
        # Simple depth estimate: count the layers that contain parameters
        param_layers = [m for m in self.model.modules() if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d))]
        stats['depth'] = len(param_layers)
@@ -471,16 +565,12 @@ class AutoInitializer:
         method, nonlinearity, override = self.recommend_config()
 
         if verbose:
-
-
-
-
-
-
-            console.print(f"[bold green]Recommended Strategy:[/bold green] {method} (nonlin={nonlinearity})")
-        else:
-            print(f"Auto Init: Depth={self.stats['depth']}, Acts={self.stats['activations']}")
-            print(f"Strategy: {method}, {nonlinearity}")
+            console = Console()
+            console.print(f"[bold cyan]Auto Initialization Analysis:[/bold cyan]")
+            console.print(f"  Depth: {self.stats['depth']}")
+            console.print(f"  Activations: {self.stats['activations']}")
+            console.print(f"  Transformer Detected: {self.stats['transformer_detected']}")
+            console.print(f"[bold green]Recommended Strategy:[/bold green] {method} (nonlin={nonlinearity})")
 
         initialize_weights(
            self.model,
orbit/utils/layer_io.py
CHANGED
@@ -2,7 +2,10 @@ import torch
 import torch.nn as nn
 from typing import Union
 
-
+from safetensors.torch import save_file as safe_save_file
+from safetensors.torch import load_file as safe_load_file
+
+def get_model_by_name(model: nn.Module, name: str) -> nn.Module:
     '''Get a submodule of a model by its name.
 
     Args:
@@ -23,7 +26,7 @@ def get_module_by_name(model: nn.Module, name: str) -> nn.Module:
         module = getattr(module, n)
     return module
 
-def save_layer_weights(model: nn.Module, layer_name: str, file_path: str) -> None:
+def save_layer(model: nn.Module, layer_name: str, file_path: str) -> None:
     '''Save the weights of a specified layer of a model to a file.
 
     Args:
@@ -31,10 +34,13 @@ def save_layer_weights(model: nn.Module, layer_name: str, file_path: str) -> None:
         layer_name (str): name of the layer whose weights should be saved.
         file_path (str): save path.
     '''
-    module =
-
+    module = get_model_by_name(model, layer_name)
+    if file_path.endswith('.safetensors'):
+        safe_save_file(module.state_dict(), file_path)
+    else:
+        torch.save(module.state_dict(), file_path)
 
-def load_layer_weights(
+def load_layer(
     model: nn.Module,
     layer_name: str,
     file_path: str,
@@ -50,6 +56,59 @@ def load_layer_weights(
         strict (bool): whether to strictly match keys. Defaults to True.
         map_location (str or torch.device): device to load onto. Defaults to 'cpu'.
     '''
-
-
+    if file_path.endswith('.safetensors'):
+        state_dict = safe_load_file(file_path, device=str(map_location))
+    else:
+        state_dict = torch.load(file_path, map_location=map_location)
+
+    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
+        state_dict = state_dict['model_state_dict']
+
+    module = get_model_by_name(model, layer_name)
+
+    prefix = layer_name + '.'
+    if any(k.startswith(prefix) for k in state_dict.keys()):
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k.startswith(prefix):
+                new_key = k[len(prefix):]
+                new_state_dict[new_key] = v
+        state_dict = new_state_dict
+
     module.load_state_dict(state_dict, strict=strict)
+
+def save_model(model: nn.Module, file_path: str) -> None:
+    '''Save the weights of an entire model to a file.
+
+    Args:
+        model (nn.Module): the target model.
+        file_path (str): save path.
+    '''
+    if file_path.endswith('.safetensors'):
+        safe_save_file(model.state_dict(), file_path)
+    else:
+        torch.save(model.state_dict(), file_path)
+
+def load_model(
+    model: nn.Module,
+    file_path: str,
+    strict: bool = True,
+    map_location: Union[str, torch.device] = 'cpu'
+) -> None:
+    '''Load weights from a file into an entire model.
+
+    Args:
+        model (nn.Module): the target model.
+        file_path (str): path to the weights file.
+        strict (bool): whether to strictly match keys. Defaults to True.
+        map_location (str or torch.device): device to load onto. Defaults to 'cpu'.
+    '''
+    if file_path.endswith('.safetensors'):
+        state_dict = safe_load_file(file_path, device=str(map_location))
+    else:
+        state_dict = torch.load(file_path, map_location=map_location)
+
+    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
+        state_dict = state_dict['model_state_dict']
+
+    model.load_state_dict(state_dict, strict=strict)
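
With these additions, layer_io picks its serialization backend from the file extension: '.safetensors' paths go through safetensors, everything else through torch.save/torch.load, and load_layer transparently accepts full-model checkpoints by stripping the layer prefix. A minimal usage sketch (file names are arbitrary):

import torch.nn as nn
from orbit.utils.layer_io import save_layer, load_layer, save_model, load_model

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

save_model(model, 'model.safetensors')     # whole model via safetensors
save_layer(model, '0', 'first_linear.pt')  # one submodule via torch.save

# load_layer also accepts the full-model checkpoint: keys under '0.' are
# re-prefixed to match the submodule's own state_dict before loading.
load_layer(model, '0', 'model.safetensors')
load_model(model, 'model.safetensors')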