orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/utils/image.py ADDED
@@ -0,0 +1,164 @@
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+from dataclasses import dataclass
+
+@dataclass
+class PatchOutput:
+    output: torch.Tensor
+    mask: torch.Tensor
+    patch_size: Tuple[int, int]
+    num_patches: Tuple[int, int]
+
+def pad_to_patch_size(image: torch.Tensor, patch_size: Tuple[int, int]) -> PatchOutput:
+    '''Pad an image to fit the patch size, without splitting it.
+
+    Takes an image tensor of shape [..., channels, width, height] and
+    zero-pads it on the right and bottom so that the padded size is
+    divisible by patch_size.
+
+    Args:
+        image (torch.Tensor): Input image tensor of shape [..., channels, w, h].
+            The last two dimensions are treated as spatial (width, height).
+        patch_size (Tuple[int, int]): Patch size (a, b), where 'a' corresponds
+            to the width dimension and 'b' to the height dimension.
+
+    Returns:
+        PatchOutput: A dataclass with the fields:
+            - output (torch.Tensor): Padded image tensor of shape [..., channels, w_padded, h_padded].
+            - mask (torch.Tensor): Mask tensor of shape [..., 1, w_padded, h_padded];
+              1 in valid regions, 0 in padded regions.
+            - patch_size (Tuple[int, int]): The input patch size (a, b).
+            - num_patches (Tuple[int, int]): Number of patches along width and height (num_w, num_h).
+
+    Raises:
+        ValueError: If the input image has fewer than 3 dimensions.
+    '''
+    if image.ndim < 3:
+        raise ValueError(f'Input image must have at least 3 dimensions, got {image.ndim}')
+
+    w, h = image.shape[-2], image.shape[-1]
+    a, b = patch_size
+
+    pad_w = (a - w % a) % a
+    pad_h = (b - h % b) % b
+
+    image_padded = F.pad(image, (0, pad_h, 0, pad_w))
+
+    w_padded = w + pad_w
+    h_padded = h + pad_h
+
+    num_w = w_padded // a
+    num_h = h_padded // b
+
+    mask = torch.ones((*image.shape[:-3], 1, w, h), dtype=image.dtype, device=image.device)
+    mask_padded = F.pad(mask, (0, pad_h, 0, pad_w), value=0)
+
+    return PatchOutput(
+        output=image_padded,
+        mask=mask_padded,
+        patch_size=patch_size,
+        num_patches=(num_w, num_h)
+    )
+
+def split_to_patches(image: torch.Tensor, patch_size: Tuple[int, int]) -> PatchOutput:
+    '''Split an image tensor into sub-images, with automatic padding.
+
+    Takes an image tensor of shape [..., channels, width, height] and divides
+    it into patches of shape [channels, patch_width, patch_height]. If the
+    image size is not divisible by patch_size, the image is zero-padded on
+    the right and bottom. The resulting tensor has shape
+    [..., num_patches_total, channels, patch_width, patch_height].
+
+    Args:
+        image (torch.Tensor): Input image tensor of shape [..., channels, w, h].
+            The last two dimensions are treated as spatial (width, height).
+        patch_size (Tuple[int, int]): Patch size (a, b), where 'a' corresponds
+            to the width dimension and 'b' to the height dimension.
+
+    Returns:
+        PatchOutput: A dataclass with the fields:
+            - output (torch.Tensor): Patch tensor of shape [..., num_w * num_h, channels, a, b].
+            - mask (torch.Tensor): Mask tensor of shape [..., num_w * num_h, 1, a, b];
+              1 in valid regions, 0 in padded regions.
+            - patch_size (Tuple[int, int]): The input patch size (a, b).
+            - num_patches (Tuple[int, int]): Number of patches along width and height (num_w, num_h).
+
+    Raises:
+        ValueError: If the input image has fewer than 3 dimensions.
+    '''
+    padded_output = pad_to_patch_size(image, patch_size)
+    image_padded = padded_output.output
+    mask_padded = padded_output.mask
+    num_w, num_h = padded_output.num_patches
+    a, b = patch_size
+
+    def split(x, nw, nh):
+        # x shape: [..., C, W, H]
+        reshaped = x.view(*x.shape[:-2], nw, a, nh, b)
+        # permute to [..., nw, nh, C, a, b]
+        permuted = reshaped.permute(
+            *range(reshaped.ndim - 5),
+            reshaped.ndim - 4,  # nw
+            reshaped.ndim - 2,  # nh
+            reshaped.ndim - 5,  # C
+            reshaped.ndim - 3,  # a
+            reshaped.ndim - 1   # b
+        )
+        return permuted.reshape(*x.shape[:-3], nw * nh, x.shape[-3], a, b)
+
+    output_patches = split(image_padded, num_w, num_h)
+    mask_patches = split(mask_padded, num_w, num_h)
+
+    return PatchOutput(
+        output=output_patches,
+        mask=mask_patches,
+        patch_size=patch_size,
+        num_patches=(num_w, num_h)
+    )
+
+def reconstruct_from_patches(patches: torch.Tensor, num_patches: Tuple[int, int], mask: torch.Tensor = None) -> torch.Tensor:
+    '''Reconstruct an image from patches.
+
+    The inverse of split_to_patches: reassembles a patch tensor into the
+    original image. If a mask is provided, the padding is cropped off
+    according to the mask.
+
+    Args:
+        patches (torch.Tensor): Patch tensor of shape [..., num_patches, channels, patch_width, patch_height].
+        num_patches (Tuple[int, int]): Number of patches along width and height (num_w, num_h).
+        mask (torch.Tensor, optional): Mask for removing padding, with the same
+            shape as patches but a single channel. If provided, the reconstructed
+            image is cropped to the valid region.
+
+    Returns:
+        torch.Tensor: Reconstructed image tensor of shape [..., channels, width, height].
+    '''
+    nw, nh = num_patches
+    if patches.shape[-4] != nw * nh:
+        raise ValueError(f"Number of patches in tensor ({patches.shape[-4]}) does not match num_patches argument ({nw} * {nh} = {nw*nh})")
+
+    a, b = patches.shape[-2], patches.shape[-1]
+
+    def unsplit(x):
+        # x: [..., nw*nh, C, a, b]
+        reshaped = x.view(*x.shape[:-4], nw, nh, *x.shape[-3:])
+        # permute to [..., C, nw, a, nh, b]
+        permuted = reshaped.permute(
+            *range(reshaped.ndim - 5),
+            reshaped.ndim - 3,  # C
+            reshaped.ndim - 5,  # nw
+            reshaped.ndim - 2,  # a
+            reshaped.ndim - 4,  # nh
+            reshaped.ndim - 1   # b
+        )
+        return permuted.reshape(*x.shape[:-4], x.shape[-3], nw * a, nh * b)
+
+    reconstructed = unsplit(patches)
+
+    if mask is not None:
+        reconstructed_mask = unsplit(mask)
+        m = reconstructed_mask.view(-1, reconstructed_mask.shape[-2], reconstructed_mask.shape[-1])
+
+        valid_h = (m[0, 0, :] > 0.5).sum().item()
+        valid_w = (m[0, :, 0] > 0.5).sum().item()
+
+        reconstructed = reconstructed[..., :int(valid_w), :int(valid_h)]
+
+    return reconstructed
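
For orientation, a minimal round trip with the utilities added above (the import path follows the file layout in this diff; the shapes and the 16x16 patch size are illustrative only):

import torch
from orbit.utils.image import split_to_patches, reconstruct_from_patches

image = torch.randn(2, 3, 100, 130)          # [batch, channels, w, h]; neither 100 nor 130 is a multiple of 16
patched = split_to_patches(image, (16, 16))  # zero-pads to 112 x 144, then splits
print(patched.output.shape)                  # torch.Size([2, 63, 3, 16, 16]) -> 7 * 9 = 63 patches
print(patched.num_patches)                   # (7, 9)

# Passing the mask back in crops the zero padding off again
restored = reconstruct_from_patches(patched.output, patched.num_patches, mask=patched.mask)
assert restored.shape == image.shape and torch.allclose(restored, image)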
orbit/utils/initialization.py CHANGED
@@ -4,13 +4,8 @@ import re
 import torch
 import torch.nn as nn
 from torch.nn.init import _calculate_fan_in_and_fan_out
-
-try:
-    from rich.console import Console
-    from rich.table import Table
-    RICH_AVAILABLE = True
-except ImportError:
-    RICH_AVAILABLE = False
+from rich.console import Console
+from rich.table import Table
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     '''Helper for truncated-normal initialization; runs in no-grad mode.
@@ -68,11 +63,63 @@ def constant_init(module, val, bias=0):
         val (float): Constant value for the weight.
         bias (float): Constant value for the bias.
     '''
+    if isinstance(module, (nn.Parameter, torch.Tensor)):
+        nn.init.constant_(module, val)
+        return
+
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.constant_(module.weight, val)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)
 
+def _init_tensor_impl(tensor, method, distribution, a, mode, nonlinearity, gain, std, trunc_a, trunc_b):
+    '''Internal: apply a single initialization method to one tensor.'''
+    info = ""
+    if method == 'kaiming':
+        if distribution == 'uniform':
+            nn.init.kaiming_uniform_(
+                tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+        else:
+            nn.init.kaiming_normal_(
+                tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+        info = f'Kaiming ({distribution}), mode={mode}, nonlin={nonlinearity}'
+
+    elif method == 'xavier':
+        if distribution == 'uniform':
+            nn.init.xavier_uniform_(tensor, gain=gain)
+        else:
+            nn.init.xavier_normal_(tensor, gain=gain)
+        info = f'Xavier ({distribution}), gain={gain}'
+
+    elif method == 'c2_xavier':
+        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+        c2_std = math.sqrt(1.0 / float(fan_in))
+        nn.init.normal_(tensor, mean=0.0, std=c2_std)
+        info = f'C2 Xavier (Normal), std={c2_std:.4f}'
+
+    elif method == 'orthogonal':
+        nn.init.orthogonal_(tensor, gain=gain)
+        info = f'Orthogonal, gain={gain}'
+
+    elif method == 'trunc_normal':
+        trunc_normal_(
+            tensor, mean=0., std=std, a=trunc_a, b=trunc_b)
+        info = f'Trunc Normal, std={std}, a={trunc_a}, b={trunc_b}'
+
+    elif method == 'normal':
+        nn.init.normal_(tensor, mean=0., std=std)
+        info = f'Normal, std={std}'
+
+    elif method == 'constant':
+        nn.init.constant_(tensor, val=gain)
+        info = f'Constant, val={gain}'
+
+    else:
+        nn.init.xavier_uniform_(tensor, gain=gain)
+        info = f'Xavier (Uniform) [Default], gain={gain}'
+
+    return info
+
 def init_weights(module, method='kaiming', distribution='normal', bias=0,
                  a=0, mode='fan_out', nonlinearity='relu',
                  gain=1,
@@ -98,45 +145,56 @@ def init_weights(module, method='kaiming', distribution='normal', bias=0,
         std (float): Standard deviation for Normal/Truncated Normal.
         trunc_a (float): Lower bound for Truncated Normal.
         trunc_b (float): Upper bound for Truncated Normal.
+
+    Returns:
+        str: Initialization details string, or None if no initialization was performed.
     '''
-    if hasattr(module, 'weight') and module.weight is not None:
-        if method == 'kaiming':
-            if distribution == 'uniform':
-                nn.init.kaiming_uniform_(
-                    module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-            else:
-                nn.init.kaiming_normal_(
-                    module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-
-        elif method == 'xavier':
-            if distribution == 'uniform':
-                nn.init.xavier_uniform_(module.weight, gain=gain)
-            else:
-                nn.init.xavier_normal_(module.weight, gain=gain)
-
-        elif method == 'c2_xavier':
-            fan_in, fan_out = _calculate_fan_in_and_fan_out(module.weight)
-            c2_std = math.sqrt(1.0 / float(fan_in))
-            nn.init.normal_(module.weight, mean=0.0, std=c2_std)
-
-        elif method == 'orthogonal':
-            nn.init.orthogonal_(module.weight, gain=gain)
-
-        elif method == 'trunc_normal':
-            trunc_normal_(
-                module.weight, mean=0., std=std, a=trunc_a, b=trunc_b)
-
-        elif method == 'normal':
-            nn.init.normal_(module.weight, mean=0., std=std)
-
-        elif method == 'constant':
-            nn.init.constant_(module.weight, val=gain)
-
-        else:
-            nn.init.xavier_uniform_(module.weight, gain=gain)
-
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
+    # 1. Handle a bare Parameter/Tensor directly
+    if isinstance(module, (nn.Parameter, torch.Tensor)):
+        return _init_tensor_impl(module, method, distribution, a, mode, nonlinearity, gain, std, trunc_a, trunc_b)
+
+    # 2. Handle a Module
+    info_parts = []
+    handled_params = set()
+
+    def init_and_record(tensor, name, is_bias=False):
+        if id(tensor) in handled_params:
+            return
+
+        if is_bias:
+            nn.init.constant_(tensor, bias)
+            # Keep the report compact: a standard bias of 0 hardly needs detail, but keep it for clarity
+            info = f"bias={bias}" if name == 'bias' else f"{name}: Constant({bias})"
+        else:
+            info = _init_tensor_impl(tensor, method, distribution, a, mode, nonlinearity, gain, std, trunc_a, trunc_b)
+            if name != 'weight':
+                info = f"{name}: {info}"
+
+        info_parts.append(info)
+        handled_params.add(id(tensor))
+
+    # A. Handle the standard 'weight' attribute first
+    if hasattr(module, 'weight') and module.weight is not None:
+        init_and_record(module.weight, 'weight', is_bias=False)
+
+    # B. Then the standard 'bias' attribute
+    if hasattr(module, 'bias') and module.bias is not None:
+        init_and_record(module.bias, 'bias', is_bias=True)
+
+    # C. Iterate over all registered parameters (handles custom names)
+    # recurse=False ensures only this module's direct parameters are handled
+    for name, param in module.named_parameters(recurse=False):
+        if id(param) in handled_params:
+            continue
+
+        # Heuristic: ndim < 2 counts as a bias-like parameter, anything else as weight-like
+        is_bias_like = param.ndim < 2
+        init_and_record(param, name, is_bias=is_bias_like)
+
+    if not info_parts:
+        return None
+
+    return ", ".join(info_parts)
 
 def init_layer_norm(module, weight=1.0, bias=0.0):
     '''Initialize a LayerNorm or GroupNorm module.
@@ -145,11 +203,21 @@ def init_layer_norm(module, weight=1.0, bias=0.0):
         module (nn.Module): The normalization module.
         weight (float): Initial value for the weight (gamma).
         bias (float): Initial value for the bias (beta).
+
+    Returns:
+        str: Initialization details string.
     '''
+    initialized = False
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.constant_(module.weight, weight)
+        initialized = True
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)
+        initialized = True
+
+    if initialized:
+        return f'Norm (w={weight}, b={bias})'
+    return None
 
 def init_embedding(module, init_method='normal', std=0.02, a=0., b=1., padding_idx=None):
     '''Initialize an Embedding layer.
@@ -161,17 +229,32 @@ def init_embedding(module, init_method='normal', std=0.02, a=0., b=1., padding_idx=None):
         a (float): Lower bound for the Uniform or Truncated Normal distribution.
         b (float): Upper bound for the Uniform or Truncated Normal distribution.
         padding_idx (int, optional): If given, the weight at the padding index is initialized to 0.
+
+    Returns:
+        str: Initialization details string.
     '''
-    if hasattr(module, 'weight') and module.weight is not None:
-        if init_method == 'normal':
-            nn.init.normal_(module.weight, mean=0., std=std)
-        elif init_method == 'trunc_normal':
-            trunc_normal_(module.weight, mean=0., std=std, a=a, b=b)
-        elif init_method == 'uniform':
-            nn.init.uniform_(module.weight, a=a, b=b)
+    if not (hasattr(module, 'weight') and module.weight is not None):
+        return None
+
+    info = ""
+    if init_method == 'normal':
+        nn.init.normal_(module.weight, mean=0., std=std)
+        info = f'Normal (std={std})'
+    elif init_method == 'trunc_normal':
+        trunc_normal_(module.weight, mean=0., std=std, a=a, b=b)
+        info = f'Trunc Normal (std={std}, [{a}, {b}])'
+    elif init_method == 'uniform':
+        nn.init.uniform_(module.weight, a=a, b=b)
+        info = f'Uniform ([{a}, {b}])'
+    else:
+        nn.init.normal_(module.weight, mean=0., std=std)
+        info = f'Normal (std={std})'
 
     if padding_idx is not None:
         module.weight.data[padding_idx].zero_()
+        info += f', pad_idx={padding_idx}'
+
+    return info
 
 def init_weights_transformer(model, n_layer=None, initializer_range=0.02,
                              residual_proj_names=('linear_out', 'fc2', 'c_proj'),
@@ -277,6 +360,25 @@ class WeightInitializer:
         '''
         init_info = []
 
+        # Handle a single Parameter/Tensor
+        if isinstance(model, (nn.Parameter, torch.Tensor)):
+            info = init_weights(
+                model,
+                method=self.method,
+                distribution=self.distribution,
+                bias=self.init_bias,
+                mode=self.mode,
+                nonlinearity=self.nonlinearity,
+                std=self.std,
+                trunc_a=self.trunc_a,
+                trunc_b=self.trunc_b
+            )
+            if info:
+                init_info.append(('Parameter/Tensor', type(model).__name__, info))
+            if verbose:
+                _print_init_info(init_info)
+            return
+
         for name, module in model.named_modules():
             current_config = {}
             if override:
@@ -291,8 +393,19 @@
             nonlinearity = current_config.get('nonlinearity', self.nonlinearity)
             std = current_config.get('std', self.std)
 
-            if isinstance(module, (nn.Conv2d, nn.Conv1d, nn.Conv3d, nn.Linear)):
-                init_weights(
+            info = None
+
+            if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.BatchNorm1d)):
+                info = init_layer_norm(
+                    module, weight=self.init_norm_weight, bias=self.init_norm_bias)
+
+            elif isinstance(module, nn.Embedding):
+                emb_method = method if method in ['normal', 'trunc_normal', 'uniform'] else 'normal'
+                info = init_embedding(module, init_method=emb_method, std=std)
+
+            else:
+                # Try generic initialization (Linear, Conv, or any other layer with weight/bias)
+                info = init_weights(
                     module,
                     method=method,
                     distribution=distribution,
@@ -303,31 +416,15 @@
                     trunc_a=self.trunc_a,
                     trunc_b=self.trunc_b
                 )
-                info = f'{method} ({distribution})'
-                if method == 'kaiming':
-                    info += f', mode={mode}, nonlin={nonlinearity}'
-                elif method == 'normal':
-                    info += f', std={std}'
-                init_info.append((name, module.__class__.__name__, info))
-
-            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.BatchNorm1d)):
-                init_layer_norm(
-                    module, weight=self.init_norm_weight, bias=self.init_norm_bias)
-                init_info.append((name, module.__class__.__name__, 'Norm (1/0)'))
 
-            elif isinstance(module, nn.Embedding):
-                init_embedding(
-                    module,
-                    init_method=method if method in ['normal', 'trunc_normal', 'uniform'] else 'normal',
-                    std=std
-                )
-                init_info.append((name, 'Embedding', f'{method} (std={std})'))
+            if info:
+                init_info.append((name, module.__class__.__name__, info))
 
         if verbose:
             _print_init_info(init_info)
 
 def _print_init_info(init_info):
-    '''Helper for printing initialization info, with optional rich formatting.
+    '''Helper for printing initialization info, formatted with rich.
 
     Args:
         init_info (list): A list of (layer_name, module_type, details) tuples.
@@ -335,22 +432,16 @@ def _print_init_info(init_info):
     if not init_info:
         return
 
-    if RICH_AVAILABLE:
-        console = Console()
-        table = Table(title="Weight Initialization Report", show_header=True, header_style="bold magenta")
-        table.add_column("Layer Name", style="cyan")
-        table.add_column("Module Type", style="green")
-        table.add_column("Initialization Details", style="yellow")
+    console = Console()
+    table = Table(title="Weight Initialization Report", show_header=True, header_style="bold magenta")
+    table.add_column("Layer Name", style="cyan")
+    table.add_column("Module Type", style="green")
+    table.add_column("Initialization Details", style="yellow")
 
-        for name, type_name, details in init_info:
-            table.add_row(name, type_name, details)
-
-        console.print(table)
-    else:
-        print(f"{'Layer Name':<40} | {'Module Type':<20} | {'Initialization Details'}")
-        print("-" * 90)
-        for name, type_name, details in init_info:
-            print(f"{name:<40} | {type_name:<20} | {details}")
+    for name, type_name, details in init_info:
+        table.add_row(str(name), str(type_name), str(details))
+
+    console.print(table)
 
 def initialize_weights(model, method='kaiming', override=None, verbose=False, **kwargs):
     '''Convenience function for initializing model weights.
@@ -398,6 +489,9 @@ class AutoInitializer:
             'transformer_detected': False
         }
 
+        if isinstance(self.model, (nn.Parameter, torch.Tensor)):
+            return stats
+
         # Rough depth estimate: count the layers that hold parameters
        param_layers = [m for m in self.model.modules() if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d))]
        stats['depth'] = len(param_layers)
@@ -471,16 +565,12 @@
         method, nonlinearity, override = self.recommend_config()
 
         if verbose:
-            if RICH_AVAILABLE:
-                console = Console()
-                console.print(f"[bold cyan]Auto Initialization Analysis:[/bold cyan]")
-                console.print(f"  Depth: {self.stats['depth']}")
-                console.print(f"  Activations: {self.stats['activations']}")
-                console.print(f"  Transformer Detected: {self.stats['transformer_detected']}")
-                console.print(f"[bold green]Recommended Strategy:[/bold green] {method} (nonlin={nonlinearity})")
-            else:
-                print(f"Auto Init: Depth={self.stats['depth']}, Acts={self.stats['activations']}")
-                print(f"Strategy: {method}, {nonlinearity}")
+            console = Console()
+            console.print(f"[bold cyan]Auto Initialization Analysis:[/bold cyan]")
+            console.print(f"  Depth: {self.stats['depth']}")
+            console.print(f"  Activations: {self.stats['activations']}")
+            console.print(f"  Transformer Detected: {self.stats['transformer_detected']}")
+            console.print(f"[bold green]Recommended Strategy:[/bold green] {method} (nonlin={nonlinearity})")
 
         initialize_weights(
             self.model,
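
A short usage sketch of the reworked initialization path. Only the signatures visible in the hunks above are assumed (initialize_weights, plus the new Parameter/Tensor support and returned details string in init_weights); the toy model is hypothetical:

import torch
import torch.nn as nn
from orbit.utils.initialization import initialize_weights, init_weights

model = nn.Sequential(
    nn.Linear(64, 128),   # handled by the generic init_weights path
    nn.LayerNorm(128),    # routed to init_layer_norm
    nn.Linear(128, 10),
)

# Walks named_modules() and, with verbose=True, prints the rich
# "Weight Initialization Report" table built by _print_init_info
initialize_weights(model, method='kaiming', verbose=True)

# init_weights now also accepts a bare Parameter and reports what it did
info = init_weights(nn.Parameter(torch.empty(128, 64)),
                    method='xavier', distribution='uniform')
print(info)  # "Xavier (uniform), gain=1"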
orbit/utils/layer_io.py CHANGED
@@ -2,7 +2,10 @@ import torch
 import torch.nn as nn
 from typing import Union
 
-def get_module_by_name(model: nn.Module, name: str) -> nn.Module:
+from safetensors.torch import save_file as safe_save_file
+from safetensors.torch import load_file as safe_load_file
+
+def get_model_by_name(model: nn.Module, name: str) -> nn.Module:
     '''Get a submodule of a model by name.
 
     Args:
@@ -23,7 +26,7 @@ def get_module_by_name(model: nn.Module, name: str) -> nn.Module:
         module = getattr(module, n)
     return module
 
-def save_layer_weights(model: nn.Module, layer_name: str, file_path: str) -> None:
+def save_layer(model: nn.Module, layer_name: str, file_path: str) -> None:
     '''Save the weights of a named layer to a file.
 
     Args:
@@ -31,10 +34,13 @@ def save_layer_weights(model: nn.Module, layer_name: str, file_path: str) -> None:
         layer_name (str): Name of the layer whose weights to save.
         file_path (str): Path to save to.
     '''
-    module = get_module_by_name(model, layer_name)
-    torch.save(module.state_dict(), file_path)
+    module = get_model_by_name(model, layer_name)
+    if file_path.endswith('.safetensors'):
+        safe_save_file(module.state_dict(), file_path)
+    else:
+        torch.save(module.state_dict(), file_path)
 
-def load_layer_weights(
+def load_layer(
     model: nn.Module,
     layer_name: str,
     file_path: str,
@@ -50,6 +56,59 @@ def load_layer_weights(
         strict (bool): Whether to strictly match keys. Defaults to True.
         map_location (str or torch.device): Device to load onto. Defaults to 'cpu'.
     '''
-    state_dict = torch.load(file_path, map_location=map_location)
-    module = get_module_by_name(model, layer_name)
+    if file_path.endswith('.safetensors'):
+        state_dict = safe_load_file(file_path, device=str(map_location))
+    else:
+        state_dict = torch.load(file_path, map_location=map_location)
+
+    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
+        state_dict = state_dict['model_state_dict']
+
+    module = get_model_by_name(model, layer_name)
+
+    prefix = layer_name + '.'
+    if any(k.startswith(prefix) for k in state_dict.keys()):
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k.startswith(prefix):
+                new_key = k[len(prefix):]
+                new_state_dict[new_key] = v
+        state_dict = new_state_dict
+
     module.load_state_dict(state_dict, strict=strict)
+
+def save_model(model: nn.Module, file_path: str) -> None:
+    '''Save the weights of an entire model to a file.
+
+    Args:
+        model (nn.Module): The target model.
+        file_path (str): Path to save to.
+    '''
+    if file_path.endswith('.safetensors'):
+        safe_save_file(model.state_dict(), file_path)
+    else:
+        torch.save(model.state_dict(), file_path)
+
+def load_model(
+    model: nn.Module,
+    file_path: str,
+    strict: bool = True,
+    map_location: Union[str, torch.device] = 'cpu'
+) -> None:
+    '''Load weights from a file into an entire model.
+
+    Args:
+        model (nn.Module): The target model.
+        file_path (str): Path to the weights file.
+        strict (bool): Whether to strictly match keys. Defaults to True.
+        map_location (str or torch.device): Device to load onto. Defaults to 'cpu'.
+    '''
+    if file_path.endswith('.safetensors'):
+        state_dict = safe_load_file(file_path, device=str(map_location))
+    else:
+        state_dict = torch.load(file_path, map_location=map_location)
+
+    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
+        state_dict = state_dict['model_state_dict']
+
+    model.load_state_dict(state_dict, strict=strict)
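
A sketch of the renamed I/O helpers, using only the functions defined in this file; the model and file names are hypothetical. The extension picks the backend: '.safetensors' goes through safetensors, anything else through torch.save/torch.load:

import torch.nn as nn
from orbit.utils.layer_io import save_layer, load_layer, save_model, load_model

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))

# Whole-model round trip in safetensors format
save_model(model, 'model.safetensors')
load_model(model, 'model.safetensors')

# Layer-level round trip; '0' names the first Linear inside the Sequential
save_layer(model, '0', 'layer0.safetensors')
load_layer(model, '0', 'layer0.safetensors')

Note that load_layer also accepts a full-model checkpoint and strips the 'layer_name.' prefix from matching keys before loading, so a single layer can be restored straight from a whole-model file.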