orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/plugin/lora.py ADDED
@@ -0,0 +1,346 @@
+ import os
+ import torch
+ import inspect
+ import json
+ 
+ from typing import Optional, List, TYPE_CHECKING
+ from orbit.callback import Event
+ from orbit.plugin.checkpoint import Checkpoint, safe_save_file, safe_load_file, safe_open
+ from orbit.utils.lora import inject_lora, freeze_backbone_only
+ 
+ if TYPE_CHECKING: from orbit.engine import Engine
+ 
+ class LoRA(Checkpoint):
+     '''LoRA plugin: integrates LoRA injection, freezing, and lightweight saving/loading.
+ 
+     This plugin inherits from Checkpoint. At the start of training it automatically converts the model into a LoRA model,
+     rebuilds the optimizer to match the new parameter set, and provides lightweight checkpointing that saves only the trainable parameters.
+     '''
+     def __init__(
+         self,
+         name: str = "lora_model",
+         path: str = "checkpoints",
+         # LoRA parameters
+         r: int = 8,
+         lora_alpha: int = 16,
+         lora_dropout: float = 0.05,
+         target_names: Optional[List[str]] = None,
+         exclude_names: Optional[List[str]] = None,
+         unlock_head_keywords: Optional[List[str]] = None,
+         gate: bool = False,
+         dora: bool = False,
+         # Checkpoint parameters
+         monitor: str = 'val_loss',
+         mode: str = 'min',
+         save_top_k: int = 1,
+         save_last: bool = True,
+         every_n_train_steps: Optional[int] = None,
+         use_safetensors: bool = False,
+         verbose: bool = True
+     ):
+         '''Initialize the LoRA plugin.
+ 
+         Args:
+             name (str): Model name prefix.
+             path (str): Directory where checkpoints are saved.
+             r (int): LoRA rank.
+             lora_alpha (int): LoRA scaling factor.
+             lora_dropout (float): LoRA dropout rate.
+             target_names (list, optional): Only inject into layers whose names contain these strings.
+             exclude_names (list, optional): Exclude layers whose names contain these strings.
+             unlock_head_keywords (list, optional): Keywords of layers to unfreeze in addition to the LoRA layers (e.g. a classification head).
+             gate (bool): Whether to use Gated LoRA.
+             dora (bool): Whether to use DoRA.
+             monitor (str): Metric to monitor.
+             mode (str): Monitoring mode ('min'/'max').
+             save_top_k (int): Number of best models to keep.
+             save_last (bool): Whether to save the last model.
+             every_n_train_steps (int, optional): Save every N steps.
+             use_safetensors (bool): Whether to save in safetensors format.
+             verbose (bool): Whether to print logs.
+         '''
+         # Initialize Checkpoint, forcing save_weights_only=False to preserve the training state;
+         # the filtering logic itself is customized in _save.
+         super().__init__(
+             name=name,
+             path=path,
+             save_weights_only=False,
+             monitor=monitor,
+             mode=mode,
+             save_top_k=save_top_k,
+             save_last=save_last,
+             every_n_train_steps=every_n_train_steps,
+             use_safetensors=use_safetensors,
+             verbose=verbose
+         )
+ 
+         self.r = r
+         self.lora_alpha = lora_alpha
+         self.lora_dropout = lora_dropout
+         self.target_names = target_names
+         self.exclude_names = exclude_names
+         self.unlock_head_keywords = unlock_head_keywords
+         self.gate = gate
+         self.dora = dora
+ 
+         self.injected = False
+ 
+     def on_init(self, event: Event):
+         engine = event.engine
+ 
+         # 0. Conflict detection: check whether other Checkpoint plugins are registered
+         other_checkpoints = [p for p in engine.plugins if isinstance(p, Checkpoint) and p is not self]
+         if other_checkpoints:
+             engine.print("[yellow]Warning: Multiple Checkpoint plugins detected. Since 'LoRA' inherits from 'Checkpoint', using both may cause conflicts (e.g. double saving). Suggest removing the standard 'Checkpoint' plugin.[/]", plugin='LoRA')
+ 
+         model = engine.unwrap_model()
+ 
+         # 1. Inject LoRA and freeze the backbone
+         if not self.injected:
+             engine.print(f"[cyan]Injecting LoRA (r={self.r}, alpha={self.lora_alpha})...[/]", plugin='LoRA')
+             inject_lora(
+                 model,
+                 r=self.r,
+                 lora_alpha=self.lora_alpha,
+                 lora_dropout=self.lora_dropout,
+                 gate=self.gate,
+                 dora=self.dora,
+                 target_names=self.target_names,
+                 exclude_names=self.exclude_names
+             )
+ 
+             freeze_backbone_only(
+                 model,
+                 unlock_head_keywords=self.unlock_head_keywords,
+                 verbose=self.verbose
+             )
+ 
+             model.to(engine.device)
+ 
+             # 2. Rebuild the optimizer
+             # The parameter set has changed (LoRA parameters added, most parameters frozen), so the old optimizer is unusable.
+             # Try to re-initialize it from the old optimizer's class and defaults.
+             if engine.optimizer:
+                 old_opt = engine.optimizer
+                 trainable_params = [p for p in model.parameters() if p.requires_grad]
+ 
+                 if not trainable_params:
+                     engine.print("[red]Warning: No trainable parameters found after LoRA injection![/]", plugin='LoRA')
+                 else:
+                     opt_cls = old_opt.__class__
+                     defaults = old_opt.defaults
+ 
+                     # Drop keys that are not constructor arguments, so extra metadata some libraries (e.g. transformers) put into defaults does not break re-creation
+                     sig = inspect.signature(opt_cls.__init__)
+                     has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
+                     if has_kwargs:
+                         filtered_defaults = defaults
+                     else:
+                         filtered_defaults = {k: v for k, v in defaults.items() if k in sig.parameters}
+ 
+                     engine.print(f"[cyan]Re-initializing Optimizer {opt_cls.__name__} for {len(trainable_params)} trainable groups...[/]", plugin='LoRA')
+ 
+                     # Create the new optimizer
+                     if opt_cls.__name__ == 'SAM':
+                         base_opt_cls = old_opt.base_optimizer.__class__
+                         new_opt = opt_cls(trainable_params, base_optimizer=base_opt_cls, **filtered_defaults)
+                     elif opt_cls.__name__ == 'Muon':
+                         muon_params = [p for p in trainable_params if p.ndim == 2]
+                         adamw_params = [p for p in trainable_params if p.ndim != 2]
+                         new_opt = opt_cls(muon_params=muon_params, adamw_params=adamw_params, **filtered_defaults)
+                     else:
+                         new_opt = opt_cls(trainable_params, **filtered_defaults)
+ 
+                     # Swap the optimizer held by the Engine
+                     engine.optimizer = new_opt
+ 
+                     # Re-attach the scheduler
+                     if engine.scheduler:
+                         if hasattr(engine.scheduler, 'optimizer'):
+                             engine.scheduler.optimizer = new_opt
+                             engine.print("[yellow]Note: Scheduler optimizer reference updated.[/]", plugin='LoRA')
+                         else:
+                             engine.print("[yellow]Warning: Scheduler exists but cannot auto-update optimizer reference. Please verify scheduler behavior.[/]", plugin='LoRA')
+ 
+             self.injected = True
+ 
+         # 3. Run the Checkpoint initialization (create directories, attempt resume)
+         super().on_init(event)
+ 
+     def _save(self, engine: 'Engine', filename: str, verbose: bool = True, is_step: bool = False):
+         """Overridden save logic: only trainable parameters (LoRA + heads) and the training state are saved."""
+ 
+         if self.monitor:
+             engine.meta[self._meta_key] = {'best_k_models': self.best_k_models}
+ 
+         raw_model = engine.unwrap_model()
+         full_state_dict = raw_model.state_dict()
+         lora_state_dict = {}
+ 
+         # Selection: keep parameters with requires_grad, plus LoRA/DoRA-related keys
+         for name, param in raw_model.named_parameters():
+             if param.requires_grad:
+                 lora_state_dict[name] = full_state_dict[name]
+ 
+         # Ensure buffers (e.g. BN running stats) in unfrozen layers are saved as well,
+         # along with any buffer whose name contains lora/dora (they usually have none)
+         for key, value in full_state_dict.items():
+             if 'lora_' in key or 'dora_' in key:
+                 if key not in lora_state_dict:
+                     lora_state_dict[key] = value
+ 
+         lora_config = {
+             'r': self.r,
+             'alpha': self.lora_alpha,
+             'target_names': self.target_names,
+             'unlock_head_keywords': self.unlock_head_keywords
+         }
+ 
+         file_path = os.path.join(self.path, filename)
+ 
+         try:
+             if self.use_safetensors and filename.endswith('.safetensors'):
+                 # safetensors mode
+                 metadata = {
+                     'epoch': str(engine.epoch),
+                     'global_step': str(engine.global_step),
+                     'batch_idx': str(engine.batch_idx),
+                     'is_step': str(is_step),
+                     'orbit_lora_config': json.dumps(lora_config)
+                 }
+                 try:
+                     metadata['meta'] = json.dumps(engine.meta)
+                 except Exception: pass
+ 
+                 safe_save_file(lora_state_dict, file_path, metadata=metadata)
+ 
+             else:
+                 # torch mode
+                 state = {
+                     'epoch': engine.epoch,
+                     'global_step': engine.global_step,
+                     'batch_idx': engine.batch_idx,
+                     'is_step': is_step,
+                     'model_state_dict': lora_state_dict,
+                     'optimizer_state_dict': engine.optimizer.state_dict() if engine.optimizer else None,
+                     'scheduler_state_dict': engine.scheduler.state_dict() if engine.scheduler else None,
+                     'scaler_state_dict': engine.scaler.state_dict() if engine.scaler else None,
+                     'meta': engine.meta,
+                     'orbit_lora_config': lora_config
+                 }
+                 torch.save(state, file_path)
+ 
+             if verbose:
+                 rel_path = os.path.relpath(file_path)
+                 file_size = os.path.getsize(file_path) / 1024 / 1024
+                 engine.print(f"Saved LoRA checkpoint: {rel_path} ({file_size:.2f} MB)", plugin='LoRA')
+         except Exception as e:
+             engine.print(f"[red]Failed to save checkpoint: {e}[/]", plugin='LoRA')
+ 
+     def _load(self, engine: 'Engine', file_path: str):
+         """Overridden load logic: loads the state dict with strict=False."""
+         engine.print(f"[cyan]Loading LoRA checkpoint from: {file_path}[/]", plugin='LoRA')
+         try:
+             raw_model = engine.unwrap_model()
+             model_sd = None
+             saved_config = {}
+ 
+             if file_path.endswith('.safetensors'):
+                 model_sd = safe_load_file(file_path, device=str(engine.device))
+ 
+                 with safe_open(file_path, framework="pt", device=str(engine.device)) as f:
+                     metadata = f.metadata()
+                     if metadata and 'orbit_lora_config' in metadata:
+                         try:
+                             saved_config = json.loads(metadata['orbit_lora_config'])
+                         except Exception: pass
+             else:
+                 checkpoint = torch.load(file_path, map_location=engine.device)
+                 if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                     model_sd = checkpoint['model_state_dict']
+                     saved_config = checkpoint.get('orbit_lora_config', {})
+                 else:
+                     model_sd = checkpoint
+ 
+             # 1. Load model parameters
+             if model_sd is not None:
+                 # Configuration check
+                 if saved_config:
+                     if saved_config.get('r') != self.r:
+                         engine.print(f"[yellow]Warning: Loaded LoRA rank ({saved_config.get('r')}) != Current rank ({self.r})[/]", plugin='LoRA')
+ 
+                 missing, unexpected = raw_model.load_state_dict(model_sd, strict=False)
+ 
+                 # Ignore missing backbone keys; only the LoRA-related keys matter here
+                 relevant_missing = [k for k in missing if 'lora_' in k or 'dora_' in k or any(h in k for h in (self.unlock_head_keywords or []))]
+                 if relevant_missing:
+                     engine.print(f"[yellow]Warning: Missing relevant keys: {relevant_missing}[/]", plugin='LoRA')
+                 else:
+                     engine.print("[green]LoRA weights loaded successfully.[/]", plugin='LoRA')
+ 
+             # 2. Restore the training state
+             if file_path.endswith('.safetensors'):
+                 # safetensors: only metadata can be restored
+                 with safe_open(file_path, framework="pt", device=str(engine.device)) as f:
+                     metadata = f.metadata()
+                     if metadata:
+                         loaded_epoch = int(metadata.get('epoch', 0))
+                         loaded_batch_idx = int(metadata.get('batch_idx', -1))
+                         is_step = metadata.get('is_step', 'False') == 'True'
+                         engine.global_step = int(metadata.get('global_step', 0))
+ 
+                         if 'meta' in metadata:
+                             try:
+                                 engine.meta.update(json.loads(metadata['meta']))
+                             except Exception: pass
+ 
+                         if is_step:
+                             engine.start_epoch = loaded_epoch
+                             engine.start_batch_idx = loaded_batch_idx
+                             msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
+                         else:
+                             engine.start_epoch = loaded_epoch + 1
+                             engine.start_batch_idx = -1
+                             msg = f"Epoch {engine.start_epoch}"
+ 
+                         engine.print(f"[green]Resumed training from {msg}. Note: Optimizer state not restored (safetensors).[/]", plugin='LoRA')
+             else:
+                 # torch: restore the full state
+                 if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                     if engine.optimizer and 'optimizer_state_dict' in checkpoint:
+                         try:
+                             engine.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                         except Exception as e:
+                             engine.print(f"[yellow]Warning: Failed to load optimizer state: {e}.[/]", plugin='LoRA')
+ 
+                     if engine.scheduler and 'scheduler_state_dict' in checkpoint:
+                         try:
+                             engine.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                         except Exception: pass
+ 
+                     if engine.scaler and 'scaler_state_dict' in checkpoint:
+                         engine.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+ 
+                     if 'meta' in checkpoint:
+                         engine.meta.update(checkpoint['meta'])
+ 
+                     loaded_epoch = checkpoint.get('epoch', 0)
+                     loaded_batch_idx = checkpoint.get('batch_idx', -1)
+                     is_step = checkpoint.get('is_step', False)
+ 
+                     if is_step:
+                         engine.start_epoch = loaded_epoch
+                         engine.start_batch_idx = loaded_batch_idx
+                         msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
+                     else:
+                         engine.start_epoch = loaded_epoch + 1
+                         engine.start_batch_idx = -1
+                         msg = f"Epoch {engine.start_epoch}"
+ 
+                     engine.global_step = checkpoint.get('global_step', 0)
+                     engine.print(f"[green]Resuming training from {msg}[/]", plugin='LoRA')
+ 
+         except Exception as e:
+             engine.print(f"[red]Failed to load checkpoint: {e}[/]", plugin='LoRA')
+             import traceback
+             traceback.print_exc()
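
Editor's note: the sketch below shows how this plugin is meant to be wired into a training run. The Engine constructor signature is not part of this diff, so the plugins=[...] argument is an assumption inferred from the attributes the plugin touches (engine.plugins, engine.optimizer, engine.device); the target_names patterns are illustrative only.

import torch.nn as nn
from orbit.engine import Engine        # import path taken from the TYPE_CHECKING import above
from orbit.plugin.lora import LoRA

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))

# On engine init, the plugin injects LoRA adapters, freezes the backbone,
# rebuilds the optimizer over the trainable subset, and thereafter saves
# only LoRA/head weights plus training state at each checkpoint.
lora_plugin = LoRA(
    name='lora_model',
    path='checkpoints',
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_names=['0', '2'],             # hypothetical name patterns to inject into
    unlock_head_keywords=['2'],          # also unfreeze the task head
    monitor='val_loss', mode='min',
)
engine = Engine(model=model, plugins=[lora_plugin])  # constructor signature assumed
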
orbit/plugin/memory_estimator.py CHANGED
@@ -8,8 +8,7 @@ from typing import TYPE_CHECKING, Optional, Union
  
  from orbit.callback import Callback, Event
  
- if TYPE_CHECKING:
-     from orbit.engine import Engine
+ if TYPE_CHECKING: from orbit.engine import Engine
  
  class MemoryEstimator(Callback):
      """
orbit/plugin/warmup.py CHANGED
@@ -1,8 +1,7 @@
  from typing import Optional, List, TYPE_CHECKING
  from orbit.callback import Callback, Event
  
- if TYPE_CHECKING:
-     from orbit.engine import Engine
+ if TYPE_CHECKING: from orbit.engine import Engine
  
  class Warmup(Callback):
      """
orbit/utils/__init__.py CHANGED
@@ -14,7 +14,7 @@ from .freeze import (
      set_trainable,
      freeze_layers,
      unfreeze_layers,
-     get_trainable_params
+     count_params, ParamStats,
  )
  from .seed import (
      seed_everything,
@@ -27,3 +27,26 @@ from .mask import (
      make_causal_mask,
      make_sliding_window_mask
  )
+ from .layer_io import (
+     get_model_by_name,
+     save_layer, load_layer,
+     save_model, load_model
+ )
+ from .lora import (
+     save_lora, load_lora, inject_lora, inject_lora_file,
+     merge_lora, unmerge_lora,
+     freeze_backbone_only,
+     LoRADiagnoser
+ )
+ from .sft import (
+     build_sft, train_sft
+ )
+ from .cuda import (
+     cuda_alloc
+ )
+ from .image import (
+     split_to_patches, reconstruct_from_patches, pad_to_patch_size
+ )
+ from .moe import (
+     set_moe_training_mode
+ )
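
Editor's note: these re-exports flatten the new helpers into the package root, so downstream code can import them directly. For instance (names verbatim from the block above; note the breaking change that get_trainable_params is gone, replaced by count_params):

from orbit.utils import (
    inject_lora, merge_lora, freeze_backbone_only,  # LoRA helpers (new)
    count_params, ParamStats,                       # replaces get_trainable_params
    cuda_alloc,                                     # CUDA allocator configuration (new)
)
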
orbit/utils/cuda.py ADDED
@@ -0,0 +1,10 @@
+ import os
+ 
+ def cuda_alloc(size: int = 64):
+     '''
+     Set the PyTorch CUDA memory allocator configuration.
+ 
+     Args:
+         size (int): Maximum split size in MB.
+     '''
+     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{size},expandable_segments:True'
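
Editor's note: PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA caching allocator is first initialized, so cuda_alloc only takes effect if it runs before the first CUDA allocation. A minimal sketch:

from orbit.utils.cuda import cuda_alloc

cuda_alloc(128)  # sets PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,expandable_segments:True

import torch
x = torch.zeros(1024, 1024, device='cuda')  # first allocation; allocator picks up the config
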
orbit/utils/freeze.py CHANGED
@@ -1,10 +1,20 @@
  import torch
  import torch.nn as nn
- from typing import Union, List, Optional, Iterable
+ from typing import Union, List, Optional, Iterable, Type
+ from dataclasses import dataclass
+ 
+ @dataclass
+ class ParamStats:
+     count: int
+     params: List[torch.Tensor]
+ 
+     def __iter__(self):
+         return iter(self.params)
  
  def set_trainable(
      model: nn.Module,
      targets: Optional[Union[str, List[str]]] = None,
+     target_classes: Optional[Union[Type[nn.Module], List[Type[nn.Module]]]] = None,
      trainable: bool = False
  ) -> None:
      '''Set the requires_grad attribute of model parameters, to freeze or unfreeze layers.
@@ -12,48 +22,82 @@ def set_trainable(
      Args:
          model (nn.Module): Target model.
          targets (str or List[str], optional): Layer or parameter name patterns to operate on.
-             - If None, all model parameters are affected.
+             - If both this and target_classes are None, all model parameters are affected.
              - If a str, all parameters whose names contain that string are affected.
              - If a List[str], all parameters whose names contain any of the listed strings are affected.
+         target_classes (Type[nn.Module] or List[Type[nn.Module]], optional): Module classes to operate on.
+             - If given, parameters of every module that is an instance of these classes (or their subclasses) are affected.
          trainable (bool): Whether the parameters should be trainable (True to unfreeze, False to freeze).
      '''
-     if targets is None:
+     if targets is None and target_classes is None:
          for param in model.parameters():
              param.requires_grad = trainable
-     else:
+         return
+ 
+     if targets is not None:
          if isinstance(targets, str):
              targets = [targets]
  
          for name, param in model.named_parameters():
-             # Check whether the parameter name contains any of the patterns in targets
-             if any(t in name for t in targets):
-                 param.requires_grad = trainable
+             if any(t in name for t in targets): param.requires_grad = trainable
+ 
+     if target_classes is not None:
+         if not isinstance(target_classes, (list, tuple)): target_classes = [target_classes]
+ 
+         target_classes = tuple(target_classes)
  
- def freeze_layers(model: nn.Module, targets: Optional[Union[str, List[str]]] = None) -> None:
+         for module in model.modules():
+             if isinstance(module, target_classes):
+                 for param in module.parameters(): param.requires_grad = trainable
+ 
+ def freeze_layers(
+     model: nn.Module,
+     targets: Optional[Union[str, List[str]]] = None,
+     target_classes: Optional[Union[Type[nn.Module], List[Type[nn.Module]]]] = None
+ ) -> None:
      '''Freeze the specified layers of the model, or all layers (requires_grad=False).
  
      Args:
          model (nn.Module): Target model.
-         targets (str or List[str], optional): Name patterns of the layers to freeze. If not given, the whole model is frozen.
+         targets (str or List[str], optional): Name patterns of the layers to freeze.
+         target_classes (Type[nn.Module] or List[Type[nn.Module]], optional): Module classes to freeze.
+             If neither targets nor target_classes is given, the whole model is frozen.
      '''
-     set_trainable(model, targets, trainable=False)
+     set_trainable(model, targets, target_classes, trainable=False)
  
- def unfreeze_layers(model: nn.Module, targets: Optional[Union[str, List[str]]] = None) -> None:
+ def unfreeze_layers(
+     model: nn.Module,
+     targets: Optional[Union[str, List[str]]] = None,
+     target_classes: Optional[Union[Type[nn.Module], List[Type[nn.Module]]]] = None
+ ) -> None:
      '''Unfreeze the specified layers of the model, or all layers (requires_grad=True).
  
      Args:
          model (nn.Module): Target model.
-         targets (str or List[str], optional): Name patterns of the layers to unfreeze. If not given, the whole model is unfrozen.
+         targets (str or List[str], optional): Name patterns of the layers to unfreeze.
+         target_classes (Type[nn.Module] or List[Type[nn.Module]], optional): Module classes to unfreeze.
+             If neither targets nor target_classes is given, the whole model is unfrozen.
      '''
-     set_trainable(model, targets, trainable=True)
+     set_trainable(model, targets, target_classes, trainable=True)
  
- def get_trainable_params(model: nn.Module) -> Iterable[torch.Tensor]:
-     '''Get the model's parameters with requires_grad=True, for use by an optimizer.
+ def count_params(model: nn.Module, mode: str = 'trainable') -> ParamStats:
+     '''Count the model's parameters and collect the parameter list.
  
      Args:
          model (nn.Module): Target model.
+         mode (str): Counting mode, one of 'trainable', 'frozen', 'all'. Defaults to 'trainable'.
  
      Returns:
-         Iterable[torch.Tensor]: An iterator over the trainable parameters.
+         ParamStats: A dataclass holding the total parameter count (count) and the parameter list (params).
      '''
-     return filter(lambda p: p.requires_grad, model.parameters())
+     if mode == 'trainable':
+         params = [p for p in model.parameters() if p.requires_grad]
+     elif mode == 'frozen':
+         params = [p for p in model.parameters() if not p.requires_grad]
+     elif mode == 'all':
+         params = list(model.parameters())
+     else:
+         raise ValueError(f"Invalid mode: {mode}. Must be one of 'trainable', 'frozen', 'all'.")
+ 
+     count = sum(p.numel() for p in params)
+     return ParamStats(count=count, params=params)
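
Editor's note: the new target_classes filter and the ParamStats return type compose as follows. Since ParamStats.__iter__ yields the underlying tensors, the result can be handed straight to an optimizer. A small self-contained sketch:

import torch
import torch.nn as nn
from orbit.utils.freeze import freeze_layers, unfreeze_layers, count_params

model = nn.Sequential(
    nn.Linear(16, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 4),
)

freeze_layers(model)                                 # freeze everything
unfreeze_layers(model, target_classes=nn.LayerNorm)  # unfreeze by module class
unfreeze_layers(model, targets='2.')                 # unfreeze by name pattern (the head)

stats = count_params(model, mode='trainable')
print(f"{stats.count} trainable parameters")
optimizer = torch.optim.AdamW(stats, lr=1e-4)        # ParamStats iterates over its params
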